mirror of
https://github.com/langgenius/dify.git
synced 2026-02-27 20:05:09 +00:00
364 lines
16 KiB
Python
364 lines
16 KiB
Python
import logging
|
|
from collections.abc import Generator
|
|
from threading import Lock
|
|
from typing import Any, cast
|
|
|
|
from sqlalchemy import select
|
|
|
|
import contexts
|
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
|
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
|
from core.datasource.entities.datasource_entities import (
|
|
DatasourceMessage,
|
|
DatasourceProviderType,
|
|
GetOnlineDocumentPageContentRequest,
|
|
OnlineDriveDownloadFileRequest,
|
|
)
|
|
from core.datasource.errors import DatasourceProviderNotFoundError
|
|
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
|
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
|
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
|
|
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
|
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
|
|
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
|
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
|
from core.db.session_factory import session_factory
|
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
|
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
|
from core.workflow.file import File
|
|
from core.workflow.file.enums import FileTransferMethod, FileType
|
|
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
|
from core.workflow.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
|
|
from factories import file_factory
|
|
from models.model import UploadFile
|
|
from models.tools import ToolFile
|
|
from services.datasource_provider_service import DatasourceProviderService
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DatasourceManager:
    """Factory/facade for datasource plugin providers and runtimes.

    Caches provider controllers in a context-local registry (``contexts``) and
    exposes streaming helpers that bridge plugin datasource messages into
    workflow node events.
    """

    @classmethod
    def get_datasource_plugin_provider(
        cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
    ) -> DatasourcePluginProviderController:
        """Return (and cache) the provider controller for ``provider_id``.

        Controllers are cached in the context-local registry so repeated
        lookups within one request avoid refetching the plugin declaration.

        :param provider_id: id of the datasource plugin provider
        :param tenant_id: tenant the provider belongs to
        :param datasource_type: which controller flavour to instantiate
        :raises DatasourceProviderNotFoundError: if the plugin provider cannot be fetched
        :raises ValueError: if ``datasource_type`` has no controller implementation
        """
        # Lazily initialize the context-local cache and its lock on first use.
        try:
            contexts.datasource_plugin_providers.get()
        except LookupError:
            contexts.datasource_plugin_providers.set({})
            contexts.datasource_plugin_providers_lock.set(Lock())

        with contexts.datasource_plugin_providers_lock.get():
            cache = contexts.datasource_plugin_providers.get()
            if provider_id in cache:
                return cache[provider_id]

            provider_entity = PluginDatasourceManager().fetch_datasource_provider(tenant_id, provider_id)
            if not provider_entity:
                raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")

            # All controller flavours take the same constructor kwargs, so a
            # lookup table replaces four duplicated match arms.
            controller_classes: dict[DatasourceProviderType, Any] = {
                DatasourceProviderType.ONLINE_DOCUMENT: OnlineDocumentDatasourcePluginProviderController,
                DatasourceProviderType.ONLINE_DRIVE: OnlineDriveDatasourcePluginProviderController,
                DatasourceProviderType.WEBSITE_CRAWL: WebsiteCrawlDatasourcePluginProviderController,
                DatasourceProviderType.LOCAL_FILE: LocalFileDatasourcePluginProviderController,
            }
            controller_class = controller_classes.get(datasource_type)
            if controller_class is None:
                raise ValueError(f"Unsupported datasource type: {datasource_type}")

            controller: DatasourcePluginProviderController = controller_class(
                entity=provider_entity.declaration,
                plugin_id=provider_entity.plugin_id,
                plugin_unique_identifier=provider_entity.plugin_unique_identifier,
                tenant_id=tenant_id,
            )
            cache[provider_id] = controller
            return controller

    @classmethod
    def get_datasource_runtime(
        cls,
        provider_id: str,
        datasource_name: str,
        tenant_id: str,
        datasource_type: DatasourceProviderType,
    ) -> DatasourcePlugin:
        """Get the datasource runtime for a named datasource of a provider.

        :param provider_id: the id of the provider
        :param datasource_name: the name of the datasource
        :param tenant_id: the tenant id
        :param datasource_type: the type of the provider
        :return: the datasource plugin
        """
        return cls.get_datasource_plugin_provider(
            provider_id,
            tenant_id,
            datasource_type,
        ).get_datasource(datasource_name)

    @classmethod
    def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str:
        """Resolve the icon URL exposed by the datasource runtime."""
        datasource_runtime = cls.get_datasource_runtime(
            provider_id=provider_id,
            datasource_name=datasource_name,
            tenant_id=tenant_id,
            datasource_type=DatasourceProviderType.value_of(datasource_type),
        )
        return datasource_runtime.get_icon_url(tenant_id)

    @classmethod
    def stream_online_results(
        cls,
        *,
        user_id: str,
        datasource_name: str,
        datasource_type: str,
        provider_id: str,
        tenant_id: str,
        provider: str,
        plugin_id: str,
        credential_id: str,
        datasource_param: DatasourceParameter | None = None,
        online_drive_request: OnlineDriveDownloadFileParam | None = None,
    ) -> Generator[DatasourceMessage, None, Any]:
        """Pull-based streaming of domain messages from datasource plugins.

        Returns a generator that yields ``DatasourceMessage`` and finally
        returns a minimal final payload (an empty dict); the node/adapter
        assembles the structured outputs. Only ONLINE_DOCUMENT and
        ONLINE_DRIVE are streamable here; other types are handled by nodes
        directly.

        :raises ValueError: when the type-specific request argument is missing
            or ``datasource_type`` is not streamable
        """
        ds_type = DatasourceProviderType.value_of(datasource_type)
        runtime = cls.get_datasource_runtime(
            provider_id=provider_id,
            datasource_name=datasource_name,
            tenant_id=tenant_id,
            datasource_type=ds_type,
        )

        # Fetch stored credentials and, when present, inject them into the
        # plugin runtime before invoking it.
        credentials = DatasourceProviderService().get_datasource_credentials(
            tenant_id=tenant_id,
            provider=provider,
            plugin_id=plugin_id,
            credential_id=credential_id,
        )

        if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
            doc_runtime = cast(OnlineDocumentDatasourcePlugin, runtime)
            if credentials:
                doc_runtime.runtime.credentials = credentials
            if datasource_param is None:
                raise ValueError("datasource_param is required for ONLINE_DOCUMENT streaming")
            inner_gen: Generator[DatasourceMessage, None, None] = doc_runtime.get_online_document_page_content(
                user_id=user_id,
                datasource_parameters=GetOnlineDocumentPageContentRequest(
                    workspace_id=datasource_param.workspace_id,
                    page_id=datasource_param.page_id,
                    type=datasource_param.type,
                ),
                provider_type=ds_type,
            )
        elif ds_type == DatasourceProviderType.ONLINE_DRIVE:
            drive_runtime = cast(OnlineDriveDatasourcePlugin, runtime)
            if credentials:
                drive_runtime.runtime.credentials = credentials
            if online_drive_request is None:
                raise ValueError("online_drive_request is required for ONLINE_DRIVE streaming")
            inner_gen = drive_runtime.online_drive_download_file(
                user_id=user_id,
                request=OnlineDriveDownloadFileRequest(
                    id=online_drive_request.id,
                    bucket=online_drive_request.bucket,
                ),
                provider_type=ds_type,
            )
        else:
            raise ValueError(f"Unsupported datasource type for streaming: {ds_type}")

        # Bridge through to caller while preserving generator return contract
        yield from inner_gen
        # No structured final data here; node/adapter will assemble outputs
        return {}

    @classmethod
    def stream_node_events(
        cls,
        *,
        node_id: str,
        user_id: str,
        datasource_name: str,
        datasource_type: str,
        provider_id: str,
        tenant_id: str,
        provider: str,
        plugin_id: str,
        credential_id: str,
        parameters_for_log: dict[str, Any],
        datasource_info: dict[str, Any],
        variable_pool: Any,
        datasource_param: DatasourceParameter | None = None,
        online_drive_request: OnlineDriveDownloadFileParam | None = None,
    ) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]:
        """Stream datasource output as workflow node events.

        Consumes the transformed message stream from
        :meth:`stream_online_results`, re-emitting text/variable chunks as
        ``StreamChunkEvent`` and resolving file-bearing messages into a
        workflow ``File``, then closes with a ``StreamCompletedEvent``
        carrying the node outputs.
        """
        ds_type = DatasourceProviderType.value_of(datasource_type)

        messages = cls.stream_online_results(
            user_id=user_id,
            datasource_name=datasource_name,
            datasource_type=datasource_type,
            provider_id=provider_id,
            tenant_id=tenant_id,
            provider=provider,
            plugin_id=plugin_id,
            credential_id=credential_id,
            datasource_param=datasource_param,
            online_drive_request=online_drive_request,
        )

        transformed = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
            messages=messages, user_id=user_id, tenant_id=tenant_id, conversation_id=None
        )

        variables: dict[str, Any] = {}
        file_out: File | None = None

        for message in transformed:
            mtype = message.type
            if mtype in {
                DatasourceMessage.MessageType.IMAGE_LINK,
                DatasourceMessage.MessageType.BINARY_LINK,
                DatasourceMessage.MessageType.IMAGE,
            }:
                is_online_source = ds_type in {
                    DatasourceProviderType.ONLINE_DRIVE,
                    DatasourceProviderType.ONLINE_DOCUMENT,
                }
                if is_online_source and isinstance(message.message, DatasourceMessage.TextMessage):
                    url = message.message.text
                    # The ToolFile id is the last path segment of the
                    # transformed URL, with any extension stripped.
                    datasource_file_id = str(url).split("/")[-1].split(".")[0]
                    with session_factory.create_session() as session:
                        stmt = select(ToolFile).where(
                            ToolFile.id == datasource_file_id, ToolFile.tenant_id == tenant_id
                        )
                        datasource_file = session.scalar(stmt)
                        if not datasource_file:
                            raise ValueError(
                                f"ToolFile not found for file_id={datasource_file_id}, tenant_id={tenant_id}"
                            )
                        # datasource_file is guaranteed non-None past the raise
                        # above, so build the File unconditionally.
                        mapping = {
                            "tool_file_id": datasource_file_id,
                            "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
                            "transfer_method": FileTransferMethod.TOOL_FILE,
                            "url": url,
                        }
                        file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
            elif mtype == DatasourceMessage.MessageType.TEXT:
                assert isinstance(message.message, DatasourceMessage.TextMessage)
                yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)
            elif mtype == DatasourceMessage.MessageType.LINK:
                assert isinstance(message.message, DatasourceMessage.TextMessage)
                yield StreamChunkEvent(
                    selector=[node_id, "text"], chunk=f"Link: {message.message.text}\n", is_final=False
                )
            elif mtype == DatasourceMessage.MessageType.VARIABLE:
                assert isinstance(message.message, DatasourceMessage.VariableMessage)
                name = message.message.variable_name
                value = message.message.variable_value
                if message.message.stream:
                    # Streamed variables accumulate into one string and are
                    # also forwarded chunk-by-chunk.
                    assert isinstance(value, str), "stream variable_value must be str"
                    variables[name] = variables.get(name, "") + value
                    yield StreamChunkEvent(selector=[node_id, name], chunk=value, is_final=False)
                else:
                    variables[name] = value
            elif mtype == DatasourceMessage.MessageType.FILE:
                if ds_type == DatasourceProviderType.ONLINE_DRIVE and message.meta:
                    f = message.meta.get("file")
                    if isinstance(f, File):
                        file_out = f
            # Any other message type carries nothing this node surfaces.

        # Close the text stream regardless of what was emitted.
        yield StreamChunkEvent(selector=[node_id, "text"], chunk="", is_final=True)

        if ds_type == DatasourceProviderType.ONLINE_DRIVE and file_out is not None:
            variable_pool.add([node_id, "file"], file_out)

        # ONLINE_DOCUMENT surfaces the accumulated variables; other types
        # surface the resolved file plus the datasource type.
        if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
            outputs: dict[str, Any] = {**variables}
        else:
            outputs = {
                "file": file_out,
                "datasource_type": ds_type,
            }
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=parameters_for_log,
                metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                outputs=outputs,
            )
        )

    @classmethod
    def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
        """Load an ``UploadFile`` row and wrap it as a workflow ``File``.

        :raises ValueError: if no matching upload file exists for the tenant
        """
        with session_factory.create_session() as session:
            # 2.0-style select for consistency with the rest of this module.
            stmt = select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id)
            upload_file = session.scalar(stmt)
            if not upload_file:
                raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")

            file_info = File(
                id=upload_file.id,
                filename=upload_file.name,
                extension="." + upload_file.extension,
                mime_type=upload_file.mime_type,
                tenant_id=tenant_id,
                type=FileType.CUSTOM,
                transfer_method=FileTransferMethod.LOCAL_FILE,
                remote_url=upload_file.source_url,
                related_id=upload_file.id,
                size=upload_file.size,
                storage_key=upload_file.key,
                url=upload_file.source_url,
            )
            return file_info
|