mirror of
https://github.com/langgenius/dify.git
synced 2026-01-08 07:14:14 +00:00
Merge branch 'feat/rag-2' of https://github.com/langgenius/dify into feat/rag-2
This commit is contained in:
@@ -48,17 +48,17 @@ class WorkflowVariablesConfigManager:
|
||||
if datasource_node_data:
|
||||
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
|
||||
|
||||
for key, value in datasource_parameters.items():
|
||||
for _, value in datasource_parameters.items():
|
||||
if value.get("value") and isinstance(value.get("value"), str):
|
||||
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
|
||||
match = re.match(pattern, value["value"])
|
||||
if match:
|
||||
full_path = match.group(1)
|
||||
last_part = full_path.split(".")[-1]
|
||||
variables_map.pop(last_part)
|
||||
variables_map.pop(last_part, None)
|
||||
if value.get("value") and isinstance(value.get("value"), list):
|
||||
last_part = value.get("value")[-1]
|
||||
variables_map.pop(last_part)
|
||||
variables_map.pop(last_part, None)
|
||||
|
||||
all_second_step_variables = list(variables_map.values())
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from configs import dify_config
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
@@ -10,14 +11,17 @@ from core.datasource.entities.datasource_entities import (
|
||||
class DatasourcePlugin(ABC):
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
icon: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
icon: str,
|
||||
) -> None:
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
self.icon = icon
|
||||
|
||||
@abstractmethod
|
||||
def datasource_provider_type(self) -> str:
|
||||
@@ -30,4 +34,8 @@ class DatasourcePlugin(ABC):
|
||||
return self.__class__(
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
icon=self.icon,
|
||||
)
|
||||
|
||||
def get_icon_url(self, tenant_id: str) -> str:
|
||||
return f"{dify_config.CONSOLE_API_URL}/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={self.icon}" # noqa: E501
|
||||
|
||||
@@ -8,7 +8,6 @@ from core.datasource.entities.datasource_entities import (
|
||||
|
||||
class LocalFileDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
@@ -19,10 +18,12 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
super().__init__(entity, runtime, icon)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def get_icon_url(self, tenant_id: str) -> str:
|
||||
return self.icon
|
||||
|
||||
@@ -15,7 +15,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
@@ -28,9 +27,8 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
super().__init__(entity, runtime, icon)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def get_online_document_pages(
|
||||
|
||||
@@ -15,7 +15,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
class OnlineDriveDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
@@ -28,9 +27,8 @@ class OnlineDriveDatasourcePlugin(DatasourcePlugin):
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
super().__init__(entity, runtime, icon)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def online_drive_browse_files(
|
||||
|
||||
@@ -13,7 +13,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
@@ -26,9 +25,8 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
super().__init__(entity, runtime, icon)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def get_website_crawl(
|
||||
|
||||
@@ -75,14 +75,17 @@ class DatasourceNode(Node):
|
||||
|
||||
node_data = self._node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
|
||||
if not datasource_type:
|
||||
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
|
||||
if not datasource_type_segement:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
datasource_type = datasource_type.value
|
||||
datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
|
||||
if not datasource_info:
|
||||
datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
|
||||
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
|
||||
if not datasource_info_segement:
|
||||
raise DatasourceNodeError("Datasource info is not set")
|
||||
datasource_info = datasource_info.value
|
||||
datasource_info_value = datasource_info_segement.value
|
||||
if not isinstance(datasource_info_value, dict):
|
||||
raise DatasourceNodeError("Invalid datasource info format")
|
||||
datasource_info: dict[str, Any] = datasource_info_value
|
||||
# get datasource runtime
|
||||
try:
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
@@ -96,6 +99,7 @@ class DatasourceNode(Node):
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||
)
|
||||
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
|
||||
except DatasourceNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
@@ -107,15 +111,7 @@ class DatasourceNode(Node):
|
||||
)
|
||||
)
|
||||
|
||||
# get parameters
|
||||
datasource_parameters = datasource_runtime.entity.parameters
|
||||
|
||||
parameters_for_log = self._generate_parameters(
|
||||
datasource_parameters=datasource_parameters,
|
||||
variable_pool=variable_pool,
|
||||
node_data=self._node_data,
|
||||
for_log=True,
|
||||
)
|
||||
parameters_for_log = datasource_info
|
||||
|
||||
try:
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@@ -123,7 +119,7 @@ class DatasourceNode(Node):
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
credential_id=datasource_info.get("credential_id"),
|
||||
credential_id=datasource_info.get("credential_id", ""),
|
||||
)
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
@@ -134,9 +130,9 @@ class DatasourceNode(Node):
|
||||
datasource_runtime.get_online_document_page_content(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=GetOnlineDocumentPageContentRequest(
|
||||
workspace_id=datasource_info.get("workspace_id"),
|
||||
page_id=datasource_info.get("page").get("page_id"),
|
||||
type=datasource_info.get("page").get("type"),
|
||||
workspace_id=datasource_info.get("workspace_id", ""),
|
||||
page_id=datasource_info.get("page", {}).get("page_id", ""),
|
||||
type=datasource_info.get("page", {}).get("type", ""),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
@@ -154,7 +150,7 @@ class DatasourceNode(Node):
|
||||
datasource_runtime.online_drive_download_file(
|
||||
user_id=self.user_id,
|
||||
request=OnlineDriveDownloadFileRequest(
|
||||
id=datasource_info.get("id"),
|
||||
id=datasource_info.get("id", ""),
|
||||
bucket=datasource_info.get("bucket"),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
@@ -209,7 +205,7 @@ class DatasourceNode(Node):
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file_info": datasource_info,
|
||||
"file": datasource_info,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -184,13 +184,13 @@ class WorkflowEntry:
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
)
|
||||
|
||||
cls.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
)
|
||||
if node_type != NodeType.DATASOURCE:
|
||||
cls.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# run node
|
||||
|
||||
@@ -834,7 +834,9 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
provider_type=tool_info["provider_type"],
|
||||
provider_id=tool_info["provider_id"],
|
||||
)
|
||||
|
||||
elif self.node_type == NodeType.DATASOURCE.value and "datasource_info" in self.execution_metadata_dict:
|
||||
datasource_info = self.execution_metadata_dict["datasource_info"]
|
||||
extras["icon"] = datasource_info.get("icon")
|
||||
return extras
|
||||
|
||||
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
|
||||
|
||||
@@ -7,7 +7,6 @@ from configs import dify_config
|
||||
from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -929,19 +929,23 @@ class RagPipelineService:
|
||||
else:
|
||||
return []
|
||||
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
|
||||
|
||||
user_input_variables_keys = []
|
||||
user_input_variables = []
|
||||
for key, value in datasource_parameters.items():
|
||||
|
||||
for _, value in datasource_parameters.items():
|
||||
if value.get("value") and isinstance(value.get("value"), str):
|
||||
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
|
||||
match = re.match(pattern, value["value"])
|
||||
if match:
|
||||
full_path = match.group(1)
|
||||
last_part = full_path.split(".")[-1]
|
||||
user_input_variables.append(variables_map.get(last_part, {}))
|
||||
user_input_variables_keys.append(last_part)
|
||||
elif value.get("value") and isinstance(value.get("value"), list):
|
||||
last_part = value.get("value")[-1]
|
||||
user_input_variables.append(variables_map.get(last_part, {}))
|
||||
user_input_variables_keys.append(last_part)
|
||||
for key, value in variables_map.items():
|
||||
if key in user_input_variables_keys:
|
||||
user_input_variables.append(value)
|
||||
|
||||
return user_input_variables
|
||||
|
||||
@@ -972,17 +976,17 @@ class RagPipelineService:
|
||||
if datasource_node_data:
|
||||
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
|
||||
|
||||
for key, value in datasource_parameters.items():
|
||||
for _, value in datasource_parameters.items():
|
||||
if value.get("value") and isinstance(value.get("value"), str):
|
||||
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
|
||||
match = re.match(pattern, value["value"])
|
||||
if match:
|
||||
full_path = match.group(1)
|
||||
last_part = full_path.split(".")[-1]
|
||||
variables_map.pop(last_part)
|
||||
variables_map.pop(last_part, None)
|
||||
elif value.get("value") and isinstance(value.get("value"), list):
|
||||
last_part = value.get("value")[-1]
|
||||
variables_map.pop(last_part)
|
||||
variables_map.pop(last_part, None)
|
||||
all_second_step_variables = list(variables_map.values())
|
||||
datasource_provider_variables = [
|
||||
item
|
||||
|
||||
Reference in New Issue
Block a user