Merge branch 'feat/rag-2' of https://github.com/langgenius/dify into feat/rag-2

Author: twwu
Date: 2025-09-03 20:33:01 +08:00
11 changed files with 56 additions and 52 deletions

View File

@@ -48,17 +48,17 @@ class WorkflowVariablesConfigManager:
         if datasource_node_data:
             datasource_parameters = datasource_node_data.get("datasource_parameters", {})
-            for key, value in datasource_parameters.items():
+            for _, value in datasource_parameters.items():
                 if value.get("value") and isinstance(value.get("value"), str):
                     pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                     match = re.match(pattern, value["value"])
                     if match:
                         full_path = match.group(1)
                         last_part = full_path.split(".")[-1]
-                        variables_map.pop(last_part)
+                        variables_map.pop(last_part, None)
                 if value.get("value") and isinstance(value.get("value"), list):
                     last_part = value.get("value")[-1]
-                    variables_map.pop(last_part)
+                    variables_map.pop(last_part, None)
         all_second_step_variables = list(variables_map.values())
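Note: the behavioral fix in this hunk is the two-argument `dict.pop`, which returns a default instead of raising `KeyError` for unknown keys. A minimal standalone sketch of the pattern (variable names mirror the diff; the sample data is illustrative):

import re

pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
variables_map = {"file_url": {"label": "File URL"}, "query": {"label": "Query"}}

# A datasource parameter whose value references a workflow variable:
value = "{{#start.file_url#}}"
match = re.match(pattern, value)
if match:
    last_part = match.group(1).split(".")[-1]  # -> "file_url"
    # pop(..., None) tolerates references to variables missing from the
    # map; the old bare pop() raised KeyError in that case.
    variables_map.pop(last_part, None)

assert "file_url" not in variables_map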

View File

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
+from configs import dify_config
 from core.datasource.__base.datasource_runtime import DatasourceRuntime
 from core.datasource.entities.datasource_entities import (
     DatasourceEntity,
@@ -10,14 +11,17 @@ from core.datasource.entities.datasource_entities import (
 class DatasourcePlugin(ABC):
     entity: DatasourceEntity
     runtime: DatasourceRuntime
+    icon: str

     def __init__(
         self,
         entity: DatasourceEntity,
         runtime: DatasourceRuntime,
+        icon: str,
     ) -> None:
         self.entity = entity
         self.runtime = runtime
+        self.icon = icon

     @abstractmethod
     def datasource_provider_type(self) -> str:
@@ -30,4 +34,8 @@ class DatasourcePlugin(ABC):
         return self.__class__(
             entity=self.entity.model_copy(),
             runtime=runtime,
+            icon=self.icon,
         )

+    def get_icon_url(self, tenant_id: str) -> str:
+        return f"{dify_config.CONSOLE_API_URL}/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={self.icon}"  # noqa: E501

View File

@@ -8,7 +8,6 @@ from core.datasource.entities.datasource_entities import (
 class LocalFileDatasourcePlugin(DatasourcePlugin):
     tenant_id: str
-    icon: str
     plugin_unique_identifier: str

     def __init__(
@@ -19,10 +18,12 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
         icon: str,
         plugin_unique_identifier: str,
     ) -> None:
-        super().__init__(entity, runtime)
+        super().__init__(entity, runtime, icon)
         self.tenant_id = tenant_id
-        self.icon = icon
         self.plugin_unique_identifier = plugin_unique_identifier

     def datasource_provider_type(self) -> str:
         return DatasourceProviderType.LOCAL_FILE
+
+    def get_icon_url(self, tenant_id: str) -> str:
+        return self.icon

View File

@@ -15,7 +15,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager
 class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
     tenant_id: str
-    icon: str
     plugin_unique_identifier: str
     entity: DatasourceEntity
     runtime: DatasourceRuntime
@@ -28,9 +27,8 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
         icon: str,
         plugin_unique_identifier: str,
     ) -> None:
-        super().__init__(entity, runtime)
+        super().__init__(entity, runtime, icon)
         self.tenant_id = tenant_id
-        self.icon = icon
         self.plugin_unique_identifier = plugin_unique_identifier

     def get_online_document_pages(

View File

@@ -15,7 +15,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager
 class OnlineDriveDatasourcePlugin(DatasourcePlugin):
     tenant_id: str
-    icon: str
     plugin_unique_identifier: str
     entity: DatasourceEntity
     runtime: DatasourceRuntime
@@ -28,9 +27,8 @@ class OnlineDriveDatasourcePlugin(DatasourcePlugin):
         icon: str,
         plugin_unique_identifier: str,
     ) -> None:
-        super().__init__(entity, runtime)
+        super().__init__(entity, runtime, icon)
         self.tenant_id = tenant_id
-        self.icon = icon
         self.plugin_unique_identifier = plugin_unique_identifier

     def online_drive_browse_files(

View File

@@ -13,7 +13,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager
 class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
     tenant_id: str
-    icon: str
     plugin_unique_identifier: str
     entity: DatasourceEntity
     runtime: DatasourceRuntime
@@ -26,9 +25,8 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
         icon: str,
         plugin_unique_identifier: str,
     ) -> None:
-        super().__init__(entity, runtime)
+        super().__init__(entity, runtime, icon)
         self.tenant_id = tenant_id
-        self.icon = icon
         self.plugin_unique_identifier = plugin_unique_identifier

     def get_website_crawl(

View File

@@ -75,14 +75,17 @@ class DatasourceNode(Node):
         node_data = self._node_data
         variable_pool = self.graph_runtime_state.variable_pool
-        datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
-        if not datasource_type:
+        datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
+        if not datasource_type_segement:
             raise DatasourceNodeError("Datasource type is not set")
-        datasource_type = datasource_type.value
-        datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
-        if not datasource_info:
+        datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
+        datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
+        if not datasource_info_segement:
             raise DatasourceNodeError("Datasource info is not set")
-        datasource_info = datasource_info.value
+        datasource_info_value = datasource_info_segement.value
+        if not isinstance(datasource_info_value, dict):
+            raise DatasourceNodeError("Invalid datasource info format")
+        datasource_info: dict[str, Any] = datasource_info_value

         # get datasource runtime
         try:
             from core.datasource.datasource_manager import DatasourceManager
@@ -96,6 +99,7 @@ class DatasourceNode(Node):
                 tenant_id=self.tenant_id,
                 datasource_type=DatasourceProviderType.value_of(datasource_type),
             )
+            datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
         except DatasourceNodeError as e:
             yield StreamCompletedEvent(
                 node_run_result=NodeRunResult(
@@ -107,15 +111,7 @@ class DatasourceNode(Node):
                 )
             )

         # get parameters
-        datasource_parameters = datasource_runtime.entity.parameters
-        parameters_for_log = self._generate_parameters(
-            datasource_parameters=datasource_parameters,
-            variable_pool=variable_pool,
-            node_data=self._node_data,
-            for_log=True,
-        )
+        parameters_for_log = datasource_info

         try:
             datasource_provider_service = DatasourceProviderService()
@@ -123,7 +119,7 @@ class DatasourceNode(Node):
                 tenant_id=self.tenant_id,
                 provider=node_data.provider_name,
                 plugin_id=node_data.plugin_id,
-                credential_id=datasource_info.get("credential_id"),
+                credential_id=datasource_info.get("credential_id", ""),
             )
             match datasource_type:
                 case DatasourceProviderType.ONLINE_DOCUMENT:
@@ -134,9 +130,9 @@ class DatasourceNode(Node):
                     datasource_runtime.get_online_document_page_content(
                         user_id=self.user_id,
                         datasource_parameters=GetOnlineDocumentPageContentRequest(
-                            workspace_id=datasource_info.get("workspace_id"),
-                            page_id=datasource_info.get("page").get("page_id"),
-                            type=datasource_info.get("page").get("type"),
+                            workspace_id=datasource_info.get("workspace_id", ""),
+                            page_id=datasource_info.get("page", {}).get("page_id", ""),
+                            type=datasource_info.get("page", {}).get("type", ""),
                         ),
                         provider_type=datasource_type,
                     )
@@ -154,7 +150,7 @@ class DatasourceNode(Node):
                     datasource_runtime.online_drive_download_file(
                         user_id=self.user_id,
                         request=OnlineDriveDownloadFileRequest(
-                            id=datasource_info.get("id"),
+                            id=datasource_info.get("id", ""),
                             bucket=datasource_info.get("bucket"),
                         ),
                         provider_type=datasource_type,
@@ -209,7 +205,7 @@ class DatasourceNode(Node):
                 inputs=parameters_for_log,
                 metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                 outputs={
-                    "file_info": datasource_info,
+                    "file": datasource_info,
                     "datasource_type": datasource_type,
                 },
             )
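Note: the core of this file's change is narrowing the variable-pool segments before use, then relying on `.get` defaults so the request builders never see `None`. A minimal sketch of the narrowing pattern, with a stand-in `Segment` class in place of the workflow's real segment wrapper:

from typing import Any

class Segment:  # stand-in for the variable pool's segment type
    def __init__(self, value: Any) -> None:
        self.value = value

def read_datasource_info(segment: Segment | None) -> dict[str, Any]:
    if not segment:
        raise ValueError("Datasource info is not set")
    value = segment.value
    if not isinstance(value, dict):
        raise ValueError("Invalid datasource info format")
    return value  # statically and dynamically known to be a dict from here on

info = read_datasource_info(Segment({"workspace_id": "ws-1"}))
# With the dict guaranteed, .get defaults keep downstream builders total:
workspace_id = info.get("workspace_id", "")
page_id = info.get("page", {}).get("page_id", "")  # safe when "page" is absent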

View File

@@ -184,13 +184,13 @@ class WorkflowEntry:
             variable_mapping=variable_mapping,
             user_inputs=user_inputs,
         )
-        cls.mapping_user_inputs_to_variable_pool(
-            variable_mapping=variable_mapping,
-            user_inputs=user_inputs,
-            variable_pool=variable_pool,
-            tenant_id=workflow.tenant_id,
-        )
+        if node_type != NodeType.DATASOURCE:
+            cls.mapping_user_inputs_to_variable_pool(
+                variable_mapping=variable_mapping,
+                user_inputs=user_inputs,
+                variable_pool=variable_pool,
+                tenant_id=workflow.tenant_id,
+            )

         try:
             # run node
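Note: single-step execution now skips the generic user-input mapping for datasource nodes, whose inputs arrive via system variables instead. A sketch of the guard, with a stand-in enum and mapper:

from enum import Enum

class NodeType(Enum):  # stand-in for the real NodeType enum
    DATASOURCE = "datasource"
    TOOL = "tool"

def run_single_step(node_type: NodeType, user_inputs: dict) -> None:
    # Datasource nodes read sys.datasource_type / sys.datasource_info,
    # so mapping user inputs into the pool would be redundant for them.
    if node_type != NodeType.DATASOURCE:
        map_user_inputs_to_pool(user_inputs)

def map_user_inputs_to_pool(user_inputs: dict) -> None:
    print("mapped:", user_inputs)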

View File

@@ -834,7 +834,9 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
                 provider_type=tool_info["provider_type"],
                 provider_id=tool_info["provider_id"],
             )
+        elif self.node_type == NodeType.DATASOURCE.value and "datasource_info" in self.execution_metadata_dict:
+            datasource_info = self.execution_metadata_dict["datasource_info"]
+            extras["icon"] = datasource_info.get("icon")
         return extras

     def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
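Note: this mirrors the existing tool branch above it: datasource executions now surface the icon stored in their execution metadata. A self-contained sketch of the lookup, with plain dicts in place of the model:

execution_metadata_dict = {
    "datasource_info": {"icon": "https://example.com/icon.png"},
}
node_type = "datasource"  # NodeType.DATASOURCE.value
extras: dict = {}

if node_type == "datasource" and "datasource_info" in execution_metadata_dict:
    datasource_info = execution_metadata_dict["datasource_info"]
    # .get leaves extras["icon"] as None for older records without the field.
    extras["icon"] = datasource_info.get("icon")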

View File

@@ -7,7 +7,6 @@ from configs import dify_config
 from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
 from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
 from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
-from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval

 logger = logging.getLogger(__name__)

View File

@@ -929,19 +929,23 @@ class RagPipelineService:
         else:
             return []
         datasource_parameters = datasource_node_data.get("datasource_parameters", {})
+        user_input_variables_keys = []
         user_input_variables = []
-        for key, value in datasource_parameters.items():
+        for _, value in datasource_parameters.items():
             if value.get("value") and isinstance(value.get("value"), str):
                 pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                 match = re.match(pattern, value["value"])
                 if match:
                     full_path = match.group(1)
                     last_part = full_path.split(".")[-1]
-                    user_input_variables.append(variables_map.get(last_part, {}))
+                    user_input_variables_keys.append(last_part)
             elif value.get("value") and isinstance(value.get("value"), list):
                 last_part = value.get("value")[-1]
-                user_input_variables.append(variables_map.get(last_part, {}))
+                user_input_variables_keys.append(last_part)
+        for key, value in variables_map.items():
+            if key in user_input_variables_keys:
+                user_input_variables.append(value)
         return user_input_variables
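Note: the reordering is the point of this hunk: matched keys are collected first, and values are then emitted in `variables_map`'s insertion order rather than the order parameters referenced them (which also de-duplicates repeated references). A minimal sketch with illustrative data:

variables_map = {"a": {"label": "A"}, "b": {"label": "B"}, "c": {"label": "C"}}
user_input_variables_keys = ["c", "a"]  # order in which parameters referenced them

user_input_variables = [
    value for key, value in variables_map.items() if key in user_input_variables_keys
]
# -> [{"label": "A"}, {"label": "C"}]: map order wins over reference order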
@@ -972,17 +976,17 @@ class RagPipelineService:
         if datasource_node_data:
             datasource_parameters = datasource_node_data.get("datasource_parameters", {})
-            for key, value in datasource_parameters.items():
+            for _, value in datasource_parameters.items():
                 if value.get("value") and isinstance(value.get("value"), str):
                     pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                     match = re.match(pattern, value["value"])
                     if match:
                         full_path = match.group(1)
                         last_part = full_path.split(".")[-1]
-                        variables_map.pop(last_part)
+                        variables_map.pop(last_part, None)
                 elif value.get("value") and isinstance(value.get("value"), list):
                     last_part = value.get("value")[-1]
-                    variables_map.pop(last_part)
+                    variables_map.pop(last_part, None)
         all_second_step_variables = list(variables_map.values())
         datasource_provider_variables = [
             item