diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py
index 75b4dd807f..a345b0e18f 100644
--- a/api/core/datasource/entities/datasource_entities.py
+++ b/api/core/datasource/entities/datasource_entities.py
@@ -26,6 +26,7 @@ class DatasourceProviderType(enum.StrEnum):
     ONLINE_DOCUMENT = "online_document"
     LOCAL_FILE = "local_file"
     WEBSITE_CRAWL = "website_crawl"
+    ONLINE_DRIVER = "online_driver"
 
     @classmethod
     def value_of(cls, value: str) -> "DatasourceProviderType":
@@ -303,3 +304,57 @@ class WebsiteCrawlMessage(BaseModel):
 
 class DatasourceMessage(ToolInvokeMessage):
     pass
+
+
+#########################
+# Online driver file
+#########################
+
+
+class OnlineDriverFile(BaseModel):
+    """
+    A single file entry in an online driver listing, described by its key and size.
+    """
+
+    key: str = Field(..., description="The key of the file")
+    size: int = Field(..., description="The size of the file")
+
+
+class OnlineDriverFileBucket(BaseModel):
+    """
+    The files listed for one (optionally named) bucket, plus a flag marking a truncated listing.
+    """
+
+    bucket: Optional[str] = Field(None, description="The bucket of the file")
+    files: list[OnlineDriverFile] = Field(..., description="The files of the bucket")
+    is_truncated: bool = Field(False, description="Whether the bucket has more files")
+
+
+class OnlineDriverBrowseFilesRequest(BaseModel):
+    """
+    Request for browsing online driver files: optional prefix/bucket filters, page size and start marker.
+    """
+
+    prefix: Optional[str] = Field(None, description="File path prefix for filtering eg: 'docs/dify/'")
+    bucket: Optional[str] = Field(None, description="Storage bucket name")
+    max_keys: int = Field(20, description="Maximum number of files to return")
+    start_after: Optional[str] = Field(
+        None, description="Pagination token for continuing from a specific file eg: 'docs/dify/1.txt'"
+    )
+
+
+class OnlineDriverBrowseFilesResponse(BaseModel):
+    """
+    Response for browsing online driver files: a list of per-bucket file listings.
+    """
+
+    result: list[OnlineDriverFileBucket] = Field(..., description="The bucket of the files")
+
+
+class OnlineDriverDownloadFileRequest(BaseModel):
+    """
+    Request for downloading one online driver file, addressed by its key and an optional bucket.
+    """
+
+    key: str = Field(..., description="The name of the file")
+    bucket: Optional[str] = Field(None, description="The name of the bucket")
diff --git a/api/core/datasource/online_driver/online_driver_plugin.py b/api/core/datasource/online_driver/online_driver_plugin.py
new file mode 100644
index 0000000000..60322457ac
--- /dev/null
+++ b/api/core/datasource/online_driver/online_driver_plugin.py
@@ -0,0 +1,73 @@
+from collections.abc import Generator
+
+from core.datasource.__base.datasource_plugin import DatasourcePlugin
+from core.datasource.__base.datasource_runtime import DatasourceRuntime
+from core.datasource.entities.datasource_entities import (
+    DatasourceEntity,
+    DatasourceMessage,
+    DatasourceProviderType,
+    OnlineDriverBrowseFilesRequest,
+    OnlineDriverBrowseFilesResponse,
+    OnlineDriverDownloadFileRequest,
+)
+from core.plugin.impl.datasource import PluginDatasourceManager
+
+
+class OnlineDriverDatasourcePlugin(DatasourcePlugin):
+    tenant_id: str
+    icon: str
+    plugin_unique_identifier: str
+    entity: DatasourceEntity
+    runtime: DatasourceRuntime
+
+    def __init__(
+        self,
+        entity: DatasourceEntity,
+        runtime: DatasourceRuntime,
+        tenant_id: str,
+        icon: str,
+        plugin_unique_identifier: str,
+    ) -> None:
+        super().__init__(entity, runtime)
+        self.tenant_id = tenant_id
+        self.icon = icon
+        self.plugin_unique_identifier = plugin_unique_identifier
+
+    def online_driver_browse_files(
+        self,
+        user_id: str,
+        request: OnlineDriverBrowseFilesRequest,
+        provider_type: str,
+    ) -> Generator[OnlineDriverBrowseFilesResponse, None, None]:
+        manager = PluginDatasourceManager()  # a new manager is created on every call
+
+        return manager.online_driver_browse_files(
+            tenant_id=self.tenant_id,
+            user_id=user_id,
+            datasource_provider=self.entity.identity.provider,
+            datasource_name=self.entity.identity.name,
+            credentials=self.runtime.credentials,
+            request=request,
+            provider_type=provider_type,
+        )
+
+    def online_driver_download_file(
+        self,
+        user_id: str,
+        request: OnlineDriverDownloadFileRequest,
+        provider_type: str,
+    ) -> Generator[DatasourceMessage, None, None]:
+        manager = PluginDatasourceManager()  # a new manager is created on every call
+
+        return manager.online_driver_download_file(
+            tenant_id=self.tenant_id,
+            user_id=user_id,
+            datasource_provider=self.entity.identity.provider,
+            datasource_name=self.entity.identity.name,
+            credentials=self.runtime.credentials,
+            request=request,
+            provider_type=provider_type,
+        )
+
+    def datasource_provider_type(self) -> str:
+        return DatasourceProviderType.ONLINE_DRIVER
diff --git a/api/core/datasource/online_driver/online_driver_provider.py b/api/core/datasource/online_driver/online_driver_provider.py
new file mode 100644
index 0000000000..edceeecd00
--- /dev/null
+++ b/api/core/datasource/online_driver/online_driver_provider.py
@@ -0,0 +1,48 @@
+from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
+from core.datasource.__base.datasource_runtime import DatasourceRuntime
+from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
+from core.datasource.online_driver.online_driver_plugin import OnlineDriverDatasourcePlugin
+
+
+class OnlineDriverDatasourcePluginProviderController(DatasourcePluginProviderController):
+    entity: DatasourceProviderEntityWithPlugin
+    plugin_id: str
+    plugin_unique_identifier: str
+
+    def __init__(
+        self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
+    ) -> None:
+        super().__init__(entity, tenant_id)
+        self.plugin_id = plugin_id
+        self.plugin_unique_identifier = plugin_unique_identifier
+
+    @property
+    def provider_type(self) -> DatasourceProviderType:
+        """
+        Return the provider type, which is always ONLINE_DRIVER for this controller.
+        """
+        return DatasourceProviderType.ONLINE_DRIVER
+
+    def get_datasource(self, datasource_name: str) -> OnlineDriverDatasourcePlugin:  # type: ignore
+        """
+        Return the datasource with the given name as a plugin instance; raise ValueError if none matches.
+        """
+        datasource_entity = next(
+            (
+                datasource_entity
+                for datasource_entity in self.entity.datasources
+                if datasource_entity.identity.name == datasource_name
+            ),
+            None,
+        )
+
+        if not datasource_entity:
+            raise ValueError(f"Datasource with name {datasource_name} not found")
+
+        return OnlineDriverDatasourcePlugin(
+            entity=datasource_entity,
+            runtime=DatasourceRuntime(tenant_id=self.tenant_id),
+            tenant_id=self.tenant_id,
+            icon=self.entity.identity.icon,
+            plugin_unique_identifier=self.plugin_unique_identifier,
+        )
diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py
index a8e98d2c1a..f38ea0555f 100644
--- a/api/core/plugin/impl/datasource.py
+++ b/api/core/plugin/impl/datasource.py
@@ -5,6 +5,9 @@ from core.datasource.entities.datasource_entities import (
     DatasourceMessage,
     GetOnlineDocumentPageContentRequest,
     OnlineDocumentPagesMessage,
+    OnlineDriverBrowseFilesRequest,
+    OnlineDriverBrowseFilesResponse,
+    OnlineDriverDownloadFileRequest,
     WebsiteCrawlMessage,
 )
 from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
@@ -191,6 +194,78 @@ class PluginDatasourceManager(BasePluginClient):
         )
         yield from response
 
+    def online_driver_browse_files(
+        self,
+        tenant_id: str,
+        user_id: str,
+        datasource_provider: str,
+        datasource_name: str,
+        credentials: dict[str, Any],
+        request: OnlineDriverBrowseFilesRequest,
+        provider_type: str,  # NOTE(review): accepted but not referenced in this method body
+    ) -> Generator[OnlineDriverBrowseFilesResponse, None, None]:
+        """
+        Browse online driver files through the plugin daemon, yielding each streamed response.
+        """
+
+        datasource_provider_id = GenericProviderID(datasource_provider)
+
+        response = self._request_with_plugin_daemon_response_stream(
+            "POST",
+            f"plugin/{tenant_id}/dispatch/datasource/online_driver_browse_files",
+            OnlineDriverBrowseFilesResponse,
+            data={
+                "user_id": user_id,
+                "data": {
+                    "provider": datasource_provider_id.provider_name,
+                    "datasource": datasource_name,
+                    "credentials": credentials,
+                    "request": request.model_dump(),
+                },
+            },
+            headers={
+                "X-Plugin-ID": datasource_provider_id.plugin_id,
+                "Content-Type": "application/json",
+            },
+        )
+        yield from response
+
+    def online_driver_download_file(
+        self,
+        tenant_id: str,
+        user_id: str,
+        datasource_provider: str,
+        datasource_name: str,
+        credentials: dict[str, Any],
+        request: OnlineDriverDownloadFileRequest,
+        provider_type: str,  # NOTE(review): accepted but not referenced in this method body
+    ) -> Generator[DatasourceMessage, None, None]:
+        """
+        Download an online driver file through the plugin daemon, yielding each streamed message.
+        """
+
+        datasource_provider_id = GenericProviderID(datasource_provider)
+
+        response = self._request_with_plugin_daemon_response_stream(
+            "POST",
+            f"plugin/{tenant_id}/dispatch/datasource/online_driver_download_file",
+            DatasourceMessage,
+            data={
+                "user_id": user_id,
+                "data": {
+                    "provider": datasource_provider_id.provider_name,
+                    "datasource": datasource_name,
+                    "credentials": credentials,
+                    "request": request.model_dump(),
+                },
+            },
+            headers={
+                "X-Plugin-ID": datasource_provider_id.plugin_id,
+                "Content-Type": "application/json",
+            },
+        )
+        yield from response
+
     def validate_provider_credentials(
         self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any]
     ) -> bool: