mirror of
https://github.com/langgenius/dify.git
synced 2026-03-15 04:07:01 +00:00
Compare commits
2 Commits
move-knowl
...
internal-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
61fab889bb | ||
|
|
2fba4931ec |
@@ -22,10 +22,10 @@ APP_WEB_URL=http://localhost:3000
|
||||
# Files URL
|
||||
FILES_URL=http://localhost:5001
|
||||
|
||||
# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
|
||||
# Set this to the internal Docker service URL for proper plugin file access.
|
||||
# Example: INTERNAL_FILES_URL=http://api:5001
|
||||
INTERNAL_FILES_URL=http://127.0.0.1:5001
|
||||
# INTERNAL_FILES_URL is used by services running in Docker to reach the API file endpoints.
|
||||
# For Docker Desktop (Mac/Windows), use http://host.docker.internal:5001 when the API runs on the host.
|
||||
# For Docker Compose on Linux, use http://api:5001 when the API runs inside the Docker network.
|
||||
INTERNAL_FILES_URL=http://host.docker.internal:5001
|
||||
|
||||
# TRIGGER URL
|
||||
TRIGGER_URL=http://localhost:5001
|
||||
|
||||
@@ -30,8 +30,6 @@ ignore_imports =
|
||||
|
||||
dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine
|
||||
dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine
|
||||
dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory
|
||||
dify_graph.nodes.loop.loop_node -> core.workflow.node_factory
|
||||
# TODO(QuantumGhost): fix the import violation later
|
||||
dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities
|
||||
|
||||
@@ -45,6 +43,7 @@ forbidden_modules =
|
||||
extensions.ext_redis
|
||||
allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
|
||||
@@ -91,8 +90,9 @@ forbidden_modules =
|
||||
core.trigger
|
||||
core.variables
|
||||
ignore_imports =
|
||||
dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory
|
||||
dify_graph.nodes.loop.loop_node -> core.workflow.node_factory
|
||||
dify_graph.nodes.agent.agent_node -> core.model_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.provider_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
dify_graph.nodes.llm.llm_utils -> core.model_manager
|
||||
dify_graph.nodes.llm.protocols -> core.model_manager
|
||||
dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
|
||||
@@ -100,6 +100,9 @@ ignore_imports =
|
||||
dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.agent.entities
|
||||
dify_graph.nodes.agent.agent_node -> core.agent.plugin_entities
|
||||
dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model
|
||||
@@ -107,10 +110,12 @@ ignore_imports =
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.agent.agent_node -> models.model
|
||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
dify_graph.nodes.llm.node -> core.model_manager
|
||||
dify_graph.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util
|
||||
@@ -119,12 +124,17 @@ ignore_imports =
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
|
||||
dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
|
||||
dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
|
||||
dify_graph.nodes.llm.node -> models.dataset
|
||||
dify_graph.nodes.agent.agent_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.signature
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.errors
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.agent.agent_node -> models
|
||||
dify_graph.nodes.llm.node -> models.model
|
||||
dify_graph.nodes.agent.agent_node -> services
|
||||
dify_graph.nodes.tool.tool_node -> services
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> configs
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
|
||||
@@ -25,8 +25,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.enums import NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@@ -509,7 +508,11 @@ class AppListApi(Resource):
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
trigger_node_types = TRIGGER_NODE_TYPES
|
||||
trigger_node_types = {
|
||||
NodeType.TRIGGER_WEBHOOK,
|
||||
NodeType.TRIGGER_SCHEDULE,
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
}
|
||||
for workflow in draft_workflows:
|
||||
node_id = None
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -10,7 +11,6 @@ from controllers.console.wraps import account_initialization_required, edit_perm
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import AppMCPServerStatus
|
||||
from models.model import AppMCPServer
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
@@ -19,6 +19,11 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
app_server_model = console_ns.model("AppServer", app_server_fields)
|
||||
|
||||
|
||||
class AppMCPServerStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
@@ -112,10 +117,9 @@ class AppMCPServerController(Resource):
|
||||
|
||||
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
|
||||
if payload.status:
|
||||
try:
|
||||
server.status = AppMCPServerStatus(payload.status)
|
||||
except ValueError:
|
||||
if payload.status not in [status.value for status in AppMCPServerStatus]:
|
||||
raise ValueError("Invalid status")
|
||||
server.status = payload.status
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.plugin.impl.exc import PluginInvokeError
|
||||
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
|
||||
from core.trigger.debug.event_selectors import (
|
||||
TriggerDebugEvent,
|
||||
TriggerDebugEventPoller,
|
||||
@@ -1210,7 +1209,7 @@ class DraftWorkflowTriggerNodeApi(Resource):
|
||||
node_type: NodeType = draft_workflow.get_node_type_from_node_config(node_config)
|
||||
event: TriggerDebugEvent | None = None
|
||||
# for schedule trigger, when run single node, just execute directly
|
||||
if node_type == TRIGGER_SCHEDULE_NODE_TYPE:
|
||||
if node_type == NodeType.TRIGGER_SCHEDULE:
|
||||
event = TriggerDebugEvent(
|
||||
workflow_args={},
|
||||
node_id=node_id,
|
||||
|
||||
@@ -43,7 +43,7 @@ from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
@@ -216,7 +216,7 @@ class AccountInitApi(Resource):
|
||||
db.session.query(InvitationCode)
|
||||
.where(
|
||||
InvitationCode.code == args.invitation_code,
|
||||
InvitationCode.status == InvitationCodeStatus.UNUSED,
|
||||
InvitationCode.status == "unused",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -224,7 +224,7 @@ class AccountInitApi(Resource):
|
||||
if not invitation_code:
|
||||
raise InvalidInvitationCodeError()
|
||||
|
||||
invitation_code.status = InvitationCodeStatus.USED
|
||||
invitation_code.status = "used"
|
||||
invitation_code.used_at = naive_utc_now()
|
||||
invitation_code.used_by_tenant_id = account.current_tenant_id
|
||||
invitation_code.used_by_account_id = account.id
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Any, Literal
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@@ -170,20 +169,6 @@ register_enum_models(
|
||||
)
|
||||
|
||||
|
||||
def _read_upload_content(file: FileStorage, max_size: int) -> bytes:
|
||||
"""
|
||||
Read the uploaded file and validate its actual size before delegating to the plugin service.
|
||||
|
||||
FileStorage.content_length is not reliable for multipart test uploads and may be zero even when
|
||||
content exists, so the controllers validate against the loaded bytes instead.
|
||||
"""
|
||||
content = file.read()
|
||||
if len(content) > max_size:
|
||||
raise ValueError("File size exceeds the maximum allowed size")
|
||||
|
||||
return content
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/debugging-key")
|
||||
class PluginDebuggingKeyApi(Resource):
|
||||
@setup_required
|
||||
@@ -299,7 +284,12 @@ class PluginUploadFromPkgApi(Resource):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
file = request.files["pkg"]
|
||||
content = _read_upload_content(file, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
|
||||
|
||||
# check file size
|
||||
if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
|
||||
raise ValueError("File size exceeds the maximum allowed size")
|
||||
|
||||
content = file.read()
|
||||
try:
|
||||
response = PluginService.upload_pkg(tenant_id, content)
|
||||
except PluginDaemonClientSideError as e:
|
||||
@@ -338,7 +328,12 @@ class PluginUploadFromBundleApi(Resource):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
file = request.files["bundle"]
|
||||
content = _read_upload_content(file, dify_config.PLUGIN_MAX_BUNDLE_SIZE)
|
||||
|
||||
# check file size
|
||||
if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
|
||||
raise ValueError("File size exceeds the maximum allowed size")
|
||||
|
||||
content = file.read()
|
||||
try:
|
||||
response = PluginService.upload_bundle(tenant_id, content)
|
||||
except PluginDaemonClientSideError as e:
|
||||
|
||||
@@ -6,13 +6,13 @@ from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import mcp_ns
|
||||
from core.mcp import types as mcp_types
|
||||
from core.mcp.server.streamable_http import handle_mcp_request
|
||||
from dify_graph.variables.input_entities import VariableEntity
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.enums import AppMCPServerStatus
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.errors import AgentMaxIterationError
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
@@ -23,6 +22,7 @@ from dify_graph.model_runtime.entities.message_entities import (
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.nodes.agent.exc import AgentMaxIterationError
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
class AgentMaxIterationError(Exception):
|
||||
"""Raised when an agent runner exceeds the configured max iteration count."""
|
||||
|
||||
def __init__(self, max_iteration: int):
|
||||
self.max_iteration = max_iteration
|
||||
super().__init__(
|
||||
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
|
||||
f"The agent was unable to complete the task within the allowed number of iterations."
|
||||
)
|
||||
@@ -5,7 +5,6 @@ from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.errors import AgentMaxIterationError
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
@@ -26,6 +25,7 @@ from dify_graph.model_runtime.entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from dify_graph.nodes.agent.exc import AgentMaxIterationError
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -69,7 +69,7 @@ from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.nodes import BuiltinNodeTypes
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
@@ -357,7 +357,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle node succeeded events."""
|
||||
# Record files if it's an answer node or end node
|
||||
if event.node_type in [BuiltinNodeTypes.ANSWER, BuiltinNodeTypes.END, BuiltinNodeTypes.LLM]:
|
||||
if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]:
|
||||
self._recorded_files.extend(
|
||||
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
|
||||
)
|
||||
|
||||
@@ -48,13 +48,12 @@ from core.app.entities.task_entities import (
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import (
|
||||
BuiltinNodeTypes,
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
@@ -443,7 +442,7 @@ class WorkflowResponseConverter:
|
||||
event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
) -> NodeStartStreamResponse | None:
|
||||
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
|
||||
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
snapshot = self._store_snapshot(event)
|
||||
@@ -465,13 +464,13 @@ class WorkflowResponseConverter:
|
||||
)
|
||||
|
||||
try:
|
||||
if event.node_type == BuiltinNodeTypes.TOOL:
|
||||
if event.node_type == NodeType.TOOL:
|
||||
response.data.extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=ToolProviderType(event.provider_type),
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
elif event.node_type == BuiltinNodeTypes.DATASOURCE:
|
||||
elif event.node_type == NodeType.DATASOURCE:
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entity = manager.fetch_datasource_provider(
|
||||
self._application_generate_entity.app_config.tenant_id,
|
||||
@@ -480,7 +479,7 @@ class WorkflowResponseConverter:
|
||||
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
|
||||
self._application_generate_entity.app_config.tenant_id
|
||||
)
|
||||
elif event.node_type == TRIGGER_PLUGIN_NODE_TYPE:
|
||||
elif event.node_type == NodeType.TRIGGER_PLUGIN:
|
||||
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
|
||||
self._application_generate_entity.app_config.tenant_id,
|
||||
event.provider_id,
|
||||
@@ -497,7 +496,7 @@ class WorkflowResponseConverter:
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
) -> NodeFinishStreamResponse | None:
|
||||
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
|
||||
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
snapshot = self._pop_snapshot(event.node_execution_id)
|
||||
@@ -555,7 +554,7 @@ class WorkflowResponseConverter:
|
||||
event: QueueNodeRetryEvent,
|
||||
task_id: str,
|
||||
) -> NodeRetryStreamResponse | None:
|
||||
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
|
||||
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
|
||||
@@ -613,7 +612,7 @@ class WorkflowResponseConverter:
|
||||
data=IterationNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
@@ -636,7 +635,7 @@ class WorkflowResponseConverter:
|
||||
data=IterationNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_title,
|
||||
index=event.index,
|
||||
created_at=int(time.time()),
|
||||
@@ -663,7 +662,7 @@ class WorkflowResponseConverter:
|
||||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_title,
|
||||
outputs=new_outputs,
|
||||
outputs_truncated=outputs_truncated,
|
||||
@@ -693,7 +692,7 @@ class WorkflowResponseConverter:
|
||||
data=LoopNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
@@ -716,7 +715,7 @@ class WorkflowResponseConverter:
|
||||
data=LoopNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_title,
|
||||
index=event.index,
|
||||
# The `pre_loop_output` field is not utilized by the frontend.
|
||||
@@ -745,7 +744,7 @@ class WorkflowResponseConverter:
|
||||
data=LoopNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_title,
|
||||
outputs=new_outputs,
|
||||
outputs_truncated=outputs_truncated,
|
||||
|
||||
@@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
build_dify_run_context,
|
||||
)
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities.graph_init_params import GraphInitParams
|
||||
from dify_graph.enums import WorkflowType
|
||||
@@ -274,8 +274,6 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
if start_node_id is None:
|
||||
start_node_id = get_default_root_node_id(graph_config)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
|
||||
|
||||
if not graph:
|
||||
|
||||
@@ -3,10 +3,7 @@ import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
@@ -32,9 +29,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
|
||||
from core.workflow.node_resolution import resolve_workflow_node_class
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
@@ -68,6 +63,7 @@ from dify_graph.graph_events import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from dify_graph.graph_events.graph import GraphRunAbortedEvent
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
@@ -141,9 +137,6 @@ class WorkflowBasedAppRunner:
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
if root_node_id is None:
|
||||
root_node_id = get_default_root_node_id(graph_config)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
||||
|
||||
@@ -315,7 +308,7 @@ class WorkflowBasedAppRunner:
|
||||
# Get node class
|
||||
node_type = target_node_config["data"].type
|
||||
node_version = str(target_node_config["data"].version)
|
||||
node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version)
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# Use the variable pool from graph_runtime_state instead of creating a new one
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
@@ -343,18 +336,6 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
@staticmethod
|
||||
def _build_agent_strategy_info(event: NodeRunStartedEvent) -> AgentStrategyInfo | None:
|
||||
raw_agent_strategy = event.extras.get("agent_strategy")
|
||||
if raw_agent_strategy is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return AgentStrategyInfo.model_validate(raw_agent_strategy)
|
||||
except ValidationError:
|
||||
logger.warning("Invalid agent strategy payload for node %s", event.node_id, exc_info=True)
|
||||
return None
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle event
|
||||
@@ -440,7 +421,7 @@ class WorkflowBasedAppRunner:
|
||||
start_at=event.start_at,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
agent_strategy=self._build_agent_strategy_info(event),
|
||||
agent_strategy=event.agent_strategy,
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
@@ -509,9 +490,7 @@ class WorkflowBasedAppRunner:
|
||||
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||
self._publish_event(
|
||||
QueueRetrieverResourcesEvent(
|
||||
retriever_resources=[
|
||||
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
|
||||
],
|
||||
retriever_resources=event.retriever_resources,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .agent_strategy import AgentStrategyInfo
|
||||
|
||||
__all__ = ["AgentStrategyInfo"]
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class AgentStrategyInfo(BaseModel):
|
||||
name: str
|
||||
icon: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
@@ -5,12 +5,13 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from dify_graph.entities import AgentNodeStrategyInit
|
||||
from dify_graph.entities.pause_reason import PauseReason
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from dify_graph.nodes import NodeType
|
||||
|
||||
|
||||
class QueueEvent(StrEnum):
|
||||
@@ -313,7 +314,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
in_iteration_id: str | None = None
|
||||
in_loop_id: str | None = None
|
||||
start_at: datetime
|
||||
agent_strategy: AgentStrategyInfo | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode, need to refactor
|
||||
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from dify_graph.entities import AgentNodeStrategyInit
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
@@ -349,7 +349,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
extras: dict[str, object] = Field(default_factory=dict)
|
||||
iteration_id: str | None = None
|
||||
loop_id: str | None = None
|
||||
agent_strategy: AgentStrategyInfo | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_STARTED
|
||||
workflow_run_id: str
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from dify_graph.conversation_variable_updater import ConversationVariableUpdater
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent
|
||||
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
|
||||
@@ -22,7 +22,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if not isinstance(event, NodeRunSucceededEvent):
|
||||
return
|
||||
if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER:
|
||||
if event.node_type != NodeType.VARIABLE_ASSIGNER:
|
||||
return
|
||||
if self.graph_runtime_state is None:
|
||||
return
|
||||
|
||||
@@ -12,7 +12,7 @@ from typing_extensions import override
|
||||
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase
|
||||
@@ -113,11 +113,11 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
def _extract_model_instance(node: Node) -> ModelInstance | None:
|
||||
try:
|
||||
match node.node_type:
|
||||
case BuiltinNodeTypes.LLM:
|
||||
case NodeType.LLM:
|
||||
return cast("LLMNode", node).model_instance
|
||||
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
|
||||
case NodeType.PARAMETER_EXTRACTOR:
|
||||
return cast("ParameterExtractorNode", node).model_instance
|
||||
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
|
||||
case NodeType.QUESTION_CLASSIFIER:
|
||||
return cast("QuestionClassifierNode", node).model_instance
|
||||
case _:
|
||||
return None
|
||||
|
||||
@@ -16,7 +16,7 @@ from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import GraphNodeEventBase
|
||||
from dify_graph.nodes.base.node import Node
|
||||
@@ -74,13 +74,16 @@ class ObservabilityLayer(GraphEngineLayer):
|
||||
def _build_parser_registry(self) -> None:
|
||||
"""Initialize parser registry for node types."""
|
||||
self._parsers = {
|
||||
BuiltinNodeTypes.TOOL: ToolNodeOTelParser(),
|
||||
BuiltinNodeTypes.LLM: LLMNodeOTelParser(),
|
||||
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
|
||||
NodeType.TOOL: ToolNodeOTelParser(),
|
||||
NodeType.LLM: LLMNodeOTelParser(),
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
|
||||
}
|
||||
|
||||
def _get_parser(self, node: Node) -> NodeOTelParser:
|
||||
return self._parsers.get(node.node_type, self._default_parser)
|
||||
node_type = getattr(node, "node_type", None)
|
||||
if isinstance(node_type, NodeType):
|
||||
return self._parsers.get(node_type, self._default_parser)
|
||||
return self._default_parser
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
|
||||
@@ -24,12 +24,12 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans
|
||||
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||
from core.db.session_factory import session_factory
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.file import File
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from dify_graph.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
@@ -58,7 +58,7 @@ from core.ops.entities.trace_entity import (
|
||||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
@@ -302,11 +302,11 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata
|
||||
):
|
||||
try:
|
||||
if node_execution.node_type == BuiltinNodeTypes.LLM:
|
||||
if node_execution.node_type == NodeType.LLM:
|
||||
node_span = self.build_workflow_llm_span(trace_info, node_execution, trace_metadata)
|
||||
elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
node_span = self.build_workflow_retrieval_span(trace_info, node_execution, trace_metadata)
|
||||
elif node_execution.node_type == BuiltinNodeTypes.TOOL:
|
||||
elif node_execution.node_type == NodeType.TOOL:
|
||||
node_span = self.build_workflow_tool_span(trace_info, node_execution, trace_metadata)
|
||||
else:
|
||||
node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)
|
||||
|
||||
@@ -155,8 +155,8 @@ def wrap_span_metadata(metadata, **kwargs):
|
||||
return metadata
|
||||
|
||||
|
||||
# Mapping from built-in node type strings to OpenInference span kinds.
|
||||
# Node types not listed here default to CHAIN.
|
||||
# Mapping from NodeType string values to OpenInference span kinds.
|
||||
# NodeType values not listed here default to CHAIN.
|
||||
_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
|
||||
"llm": OpenInferenceSpanKindValues.LLM,
|
||||
"knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER,
|
||||
@@ -168,7 +168,7 @@ _NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
|
||||
def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
|
||||
"""Return the OpenInference span kind for a given workflow node type.
|
||||
|
||||
Covers every built-in node type string. Nodes that do not have a
|
||||
Covers every ``NodeType`` enum value. Nodes that do not have a
|
||||
specialised span kind (e.g. ``start``, ``end``, ``if-else``,
|
||||
``code``, ``loop``, ``iteration``, etc.) are mapped to ``CHAIN``.
|
||||
"""
|
||||
|
||||
@@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
||||
)
|
||||
from core.ops.utils import filter_none_values
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import MessageStatus
|
||||
@@ -141,7 +141,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
if node_type == NodeType.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs or {}
|
||||
|
||||
@@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
||||
)
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
@@ -163,7 +163,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
if node_type == NodeType.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs or {}
|
||||
@@ -197,7 +197,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
"ls_model_name": process_data.get("model_name", ""),
|
||||
}
|
||||
)
|
||||
elif node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
run_type = LangSmithRunType.retriever
|
||||
else:
|
||||
run_type = LangSmithRunType.tool
|
||||
|
||||
@@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
@@ -145,10 +145,10 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
"app_name": node.title,
|
||||
}
|
||||
|
||||
if node.node_type in (BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER):
|
||||
if node.node_type in (NodeType.LLM, NodeType.QUESTION_CLASSIFIER):
|
||||
inputs, llm_attributes = self._parse_llm_inputs_and_attributes(node)
|
||||
attributes.update(llm_attributes)
|
||||
elif node.node_type == BuiltinNodeTypes.HTTP_REQUEST:
|
||||
elif node.node_type == NodeType.HTTP_REQUEST:
|
||||
inputs = node.process_data # contains request URL
|
||||
|
||||
if not inputs:
|
||||
@@ -180,9 +180,9 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
# End node span
|
||||
finished_at = node.created_at + timedelta(seconds=node.elapsed_time)
|
||||
outputs = json.loads(node.outputs) if node.outputs else {}
|
||||
if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
if node.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
outputs = self._parse_knowledge_retrieval_outputs(outputs)
|
||||
elif node.node_type == BuiltinNodeTypes.LLM:
|
||||
elif node.node_type == NodeType.LLM:
|
||||
outputs = outputs.get("text", outputs)
|
||||
node_span.end(
|
||||
outputs=outputs,
|
||||
@@ -471,13 +471,13 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
def _get_node_span_type(self, node_type: str) -> str:
|
||||
"""Map Dify node types to MLflow span types"""
|
||||
node_type_mapping = {
|
||||
BuiltinNodeTypes.LLM: SpanType.LLM,
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER: SpanType.LLM,
|
||||
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
|
||||
BuiltinNodeTypes.TOOL: SpanType.TOOL,
|
||||
BuiltinNodeTypes.CODE: SpanType.TOOL,
|
||||
BuiltinNodeTypes.HTTP_REQUEST: SpanType.TOOL,
|
||||
BuiltinNodeTypes.AGENT: SpanType.AGENT,
|
||||
NodeType.LLM: SpanType.LLM,
|
||||
NodeType.QUESTION_CLASSIFIER: SpanType.LLM,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
|
||||
NodeType.TOOL: SpanType.TOOL,
|
||||
NodeType.CODE: SpanType.TOOL,
|
||||
NodeType.HTTP_REQUEST: SpanType.TOOL,
|
||||
NodeType.AGENT: SpanType.AGENT,
|
||||
}
|
||||
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import (
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
@@ -187,7 +187,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
if node_type == NodeType.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs or {}
|
||||
|
||||
@@ -27,7 +27,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from dify_graph.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from dify_graph.nodes import BuiltinNodeTypes
|
||||
from dify_graph.nodes import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
@@ -179,7 +179,7 @@ class TencentDataTrace(BaseTraceInstance):
|
||||
if node_span:
|
||||
self.trace_client.add_span(node_span)
|
||||
|
||||
if node_execution.node_type == BuiltinNodeTypes.LLM:
|
||||
if node_execution.node_type == NodeType.LLM:
|
||||
self._record_llm_metrics(node_execution)
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id)
|
||||
@@ -192,15 +192,15 @@ class TencentDataTrace(BaseTraceInstance):
|
||||
) -> SpanData | None:
|
||||
"""Build span for different node types"""
|
||||
try:
|
||||
if node_execution.node_type == BuiltinNodeTypes.LLM:
|
||||
if node_execution.node_type == NodeType.LLM:
|
||||
return TencentSpanBuilder.build_workflow_llm_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
return TencentSpanBuilder.build_workflow_retrieval_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
elif node_execution.node_type == BuiltinNodeTypes.TOOL:
|
||||
elif node_execution.node_type == NodeType.TOOL:
|
||||
return TencentSpanBuilder.build_workflow_tool_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
|
||||
@@ -31,7 +31,7 @@ from core.ops.entities.trace_entity import (
|
||||
)
|
||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
@@ -175,7 +175,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
if node_type == NodeType.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs or {}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.parameter_extractor.entities import (
|
||||
ModelConfig as ParameterExtractorModelConfig,
|
||||
)
|
||||
@@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
||||
instruction=instruction, # instruct with variables are not supported
|
||||
)
|
||||
node_data_dict = node_data.model_dump()
|
||||
node_data_dict["type"] = BuiltinNodeTypes.PARAMETER_EXTRACTOR
|
||||
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR
|
||||
execution = workflow_service.run_free_workflow_node(
|
||||
node_data_dict,
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -9,8 +9,8 @@ from flask import current_app
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
|
||||
from dify_graph.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||
from dify_graph.repositories.index_processor_protocol import Preview, PreviewItem, QaPreview
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
from .index_processor_factory import IndexProcessorFactory
|
||||
|
||||
@@ -56,18 +56,18 @@ from core.rag.retrieval.template_prompts import (
|
||||
)
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.workflow.nodes.knowledge_retrieval import exc
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import (
|
||||
KnowledgeRetrievalRequest,
|
||||
Source,
|
||||
SourceChildChunk,
|
||||
SourceMetadata,
|
||||
)
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from dify_graph.nodes.knowledge_retrieval import exc
|
||||
from dify_graph.repositories.rag_retrieval_protocol import (
|
||||
KnowledgeRetrievalRequest,
|
||||
Source,
|
||||
SourceChildChunk,
|
||||
SourceMetadata,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
|
||||
@@ -18,7 +18,7 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att
|
||||
|
||||
from configs import dify_config
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
@@ -146,7 +146,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
index=db_model.index,
|
||||
predecessor_node_id=db_model.predecessor_node_id,
|
||||
node_id=db_model.node_id,
|
||||
node_type=db_model.node_type,
|
||||
node_type=NodeType(db_model.node_type),
|
||||
title=db_model.title,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import OutputVariableEntity
|
||||
from dify_graph.variables.input_entities import VariableEntity
|
||||
|
||||
@@ -51,7 +51,7 @@ class WorkflowToolConfigurationUtils:
|
||||
def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None:
|
||||
nodes = graph.get("nodes", [])
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") == BuiltinNodeTypes.HUMAN_INPUT:
|
||||
if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT:
|
||||
raise WorkflowToolHumanInputNotSupportedError()
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
from typing import Final
|
||||
|
||||
TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook"
|
||||
TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule"
|
||||
TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin"
|
||||
TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info"
|
||||
|
||||
TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
TRIGGER_WEBHOOK_NODE_TYPE,
|
||||
TRIGGER_SCHEDULE_NODE_TYPE,
|
||||
TRIGGER_PLUGIN_NODE_TYPE,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def is_trigger_node_type(node_type: str) -> bool:
|
||||
return node_type in TRIGGER_NODE_TYPES
|
||||
@@ -11,11 +11,6 @@ from typing import Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.request import TriggerInvokeEventResponse
|
||||
from core.trigger.constants import (
|
||||
TRIGGER_PLUGIN_NODE_TYPE,
|
||||
TRIGGER_SCHEDULE_NODE_TYPE,
|
||||
TRIGGER_WEBHOOK_NODE_TYPE,
|
||||
)
|
||||
from core.trigger.debug.event_bus import TriggerDebugEventBus
|
||||
from core.trigger.debug.events import (
|
||||
PluginTriggerDebugEvent,
|
||||
@@ -24,9 +19,10 @@ from core.trigger.debug.events import (
|
||||
build_plugin_pool_key,
|
||||
build_webhook_pool_key,
|
||||
)
|
||||
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
|
||||
from libs.schedule_utils import calculate_next_run_at
|
||||
@@ -210,19 +206,21 @@ def create_event_poller(
|
||||
if not node_config:
|
||||
raise ValueError("Node data not found for node %s", node_id)
|
||||
node_type = draft_workflow.get_node_type_from_node_config(node_config)
|
||||
if node_type == TRIGGER_PLUGIN_NODE_TYPE:
|
||||
return PluginTriggerDebugEventPoller(
|
||||
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
|
||||
)
|
||||
if node_type == TRIGGER_WEBHOOK_NODE_TYPE:
|
||||
return WebhookTriggerDebugEventPoller(
|
||||
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
|
||||
)
|
||||
if node_type == TRIGGER_SCHEDULE_NODE_TYPE:
|
||||
return ScheduleTriggerDebugEventPoller(
|
||||
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
|
||||
)
|
||||
raise ValueError("unable to create event poller for node type %s", node_type)
|
||||
match node_type:
|
||||
case NodeType.TRIGGER_PLUGIN:
|
||||
return PluginTriggerDebugEventPoller(
|
||||
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
|
||||
)
|
||||
case NodeType.TRIGGER_WEBHOOK:
|
||||
return WebhookTriggerDebugEventPoller(
|
||||
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
|
||||
)
|
||||
case NodeType.TRIGGER_SCHEDULE:
|
||||
return ScheduleTriggerDebugEventPoller(
|
||||
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
|
||||
)
|
||||
case _:
|
||||
raise ValueError("unable to create event poller for node type %s", node_type)
|
||||
|
||||
|
||||
def select_trigger_debug_events(
|
||||
|
||||
@@ -1 +1,4 @@
|
||||
"""Core workflow package."""
|
||||
from .node_factory import DifyNodeFactory
|
||||
from .workflow_entry import WorkflowEntry
|
||||
|
||||
__all__ = ["DifyNodeFactory", "WorkflowEntry"]
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
import importlib
|
||||
import pkgutil
|
||||
from collections.abc import Callable, Iterator, Mapping, MutableMapping
|
||||
from functools import lru_cache
|
||||
from types import MappingProxyType
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
|
||||
|
||||
from sqlalchemy import select
|
||||
@@ -12,6 +8,7 @@ from typing_extensions import override
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext
|
||||
from core.app.llm.model_access import build_dify_model_access
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.helper.code_executor.code_executor import (
|
||||
CodeExecutionError,
|
||||
CodeExecutor,
|
||||
@@ -20,19 +17,15 @@ from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.rag.index_processor.index_processor import IndexProcessor
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
|
||||
from core.workflow.nodes.agent.plugin_strategy_adapter import (
|
||||
PluginAgentStrategyPresentationProvider,
|
||||
PluginAgentStrategyResolver,
|
||||
)
|
||||
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey
|
||||
from dify_graph.enums import NodeType, SystemVariableKey
|
||||
from dify_graph.file.file_manager import file_manager
|
||||
from dify_graph.graph.graph import NodeFactory
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
@@ -46,6 +39,7 @@ from dify_graph.nodes.document_extractor import UnstructuredApiConfig
|
||||
from dify_graph.nodes.http_request import build_http_request_config
|
||||
from dify_graph.nodes.llm.entities import LLMNodeData
|
||||
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from dify_graph.nodes.template_transform.template_renderer import (
|
||||
@@ -59,122 +53,6 @@ if TYPE_CHECKING:
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
LATEST_VERSION = "latest"
|
||||
_START_NODE_TYPES: frozenset[NodeType] = frozenset(
|
||||
(BuiltinNodeTypes.START, BuiltinNodeTypes.DATASOURCE, *TRIGGER_NODE_TYPES)
|
||||
)
|
||||
|
||||
|
||||
def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] = frozenset()) -> None:
|
||||
package = importlib.import_module(package_name)
|
||||
for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."):
|
||||
if module_name in excluded_modules:
|
||||
continue
|
||||
importlib.import_module(module_name)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def register_nodes() -> None:
|
||||
"""Import production node modules so they self-register with ``Node``."""
|
||||
_import_node_package("dify_graph.nodes")
|
||||
_import_node_package("core.workflow.nodes")
|
||||
|
||||
|
||||
def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]:
|
||||
"""Return a read-only snapshot of the current production node registry.
|
||||
|
||||
The workflow layer owns node bootstrap because it must compose built-in
|
||||
`dify_graph.nodes.*` implementations with workflow-local nodes under
|
||||
`core.workflow.nodes.*`. Keeping this import side effect here avoids
|
||||
reintroducing registry bootstrapping into lower-level graph primitives.
|
||||
"""
|
||||
register_nodes()
|
||||
return {node_type: MappingProxyType(version_map) for node_type, version_map in Node._registry.items()}
|
||||
|
||||
|
||||
def is_start_node_type(node_type: NodeType) -> bool:
|
||||
"""Return True when the node type can serve as a workflow entry point."""
|
||||
return node_type in _START_NODE_TYPES
|
||||
|
||||
|
||||
def get_default_root_node_id(graph_config: Mapping[str, Any]) -> str:
|
||||
"""Resolve the default entry node for a persisted top-level workflow graph.
|
||||
|
||||
This workflow-layer helper depends on start-node semantics defined by
|
||||
`is_start_node_type`, so it intentionally lives next to the node registry
|
||||
instead of in the raw `dify_graph.entities.graph_config` schema module.
|
||||
"""
|
||||
nodes = graph_config.get("nodes")
|
||||
if not isinstance(nodes, list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
for node in nodes:
|
||||
if not isinstance(node, Mapping):
|
||||
continue
|
||||
|
||||
if node.get("type") == "custom-note":
|
||||
continue
|
||||
|
||||
node_id = node.get("id")
|
||||
data = node.get("data")
|
||||
if not isinstance(node_id, str) or not isinstance(data, Mapping):
|
||||
continue
|
||||
|
||||
node_type = data.get("type")
|
||||
if isinstance(node_type, str) and is_start_node_type(node_type):
|
||||
return node_id
|
||||
|
||||
raise ValueError("Unable to determine default root node ID from workflow graph")
|
||||
|
||||
|
||||
class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Node]]]):
|
||||
"""Mutable dict-like view over the current node registry."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._cached_snapshot: dict[NodeType, Mapping[str, type[Node]]] = {}
|
||||
self._cached_version = -1
|
||||
self._deleted: set[NodeType] = set()
|
||||
self._overrides: dict[NodeType, Mapping[str, type[Node]]] = {}
|
||||
|
||||
def _snapshot(self) -> dict[NodeType, Mapping[str, type[Node]]]:
|
||||
current_version = Node.get_registry_version()
|
||||
if self._cached_version != current_version:
|
||||
self._cached_snapshot = dict(get_node_type_classes_mapping())
|
||||
self._cached_version = current_version
|
||||
if not self._deleted and not self._overrides:
|
||||
return self._cached_snapshot
|
||||
|
||||
snapshot = {key: value for key, value in self._cached_snapshot.items() if key not in self._deleted}
|
||||
snapshot.update(self._overrides)
|
||||
return snapshot
|
||||
|
||||
def __getitem__(self, key: NodeType) -> Mapping[str, type[Node]]:
|
||||
return self._snapshot()[key]
|
||||
|
||||
def __setitem__(self, key: NodeType, value: Mapping[str, type[Node]]) -> None:
|
||||
self._deleted.discard(key)
|
||||
self._overrides[key] = value
|
||||
|
||||
def __delitem__(self, key: NodeType) -> None:
|
||||
if key in self._overrides:
|
||||
del self._overrides[key]
|
||||
return
|
||||
if key in self._cached_snapshot:
|
||||
self._deleted.add(key)
|
||||
return
|
||||
raise KeyError(key)
|
||||
|
||||
def __iter__(self) -> Iterator[NodeType]:
|
||||
return iter(self._snapshot())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._snapshot())
|
||||
|
||||
|
||||
# Keep the canonical node-class mapping in the workflow layer that also bootstraps
|
||||
# legacy `core.workflow.nodes.*` registrations.
|
||||
NODE_TYPE_CLASSES_MAPPING: MutableMapping[NodeType, Mapping[str, type[Node]]] = _LazyNodeTypeClassesMapping()
|
||||
|
||||
|
||||
LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
|
||||
|
||||
@@ -219,7 +97,10 @@ class DefaultWorkflowCodeExecutor:
|
||||
@final
|
||||
class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
Default implementation of NodeFactory that resolves node classes from the live registry.
|
||||
Default implementation of NodeFactory that uses the traditional node mapping.
|
||||
|
||||
This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
|
||||
and instantiating the appropriate node class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -246,6 +127,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
self._http_request_http_client = ssrf_proxy
|
||||
self._http_request_tool_file_manager_factory = ToolFileManager
|
||||
self._http_request_file_manager = file_manager
|
||||
self._rag_retrieval = DatasetRetrieval()
|
||||
self._document_extractor_unstructured_api_config = UnstructuredApiConfig(
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY or "",
|
||||
@@ -261,10 +143,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
)
|
||||
|
||||
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
|
||||
self._agent_strategy_resolver = PluginAgentStrategyResolver()
|
||||
self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider()
|
||||
self._agent_runtime_support = AgentRuntimeSupport()
|
||||
self._agent_message_transformer = AgentMessageTransformer()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext:
|
||||
@@ -292,51 +170,55 @@ class DifyNodeFactory(NodeFactory):
|
||||
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
|
||||
node_type = node_data.type
|
||||
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
|
||||
BuiltinNodeTypes.CODE: lambda: {
|
||||
NodeType.CODE: lambda: {
|
||||
"code_executor": self._code_executor,
|
||||
"code_limits": self._code_limits,
|
||||
},
|
||||
BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: {
|
||||
NodeType.TEMPLATE_TRANSFORM: lambda: {
|
||||
"template_renderer": self._template_renderer,
|
||||
"max_output_length": self._template_transform_max_output_length,
|
||||
},
|
||||
BuiltinNodeTypes.HTTP_REQUEST: lambda: {
|
||||
NodeType.HTTP_REQUEST: lambda: {
|
||||
"http_request_config": self._http_request_config,
|
||||
"http_client": self._http_request_http_client,
|
||||
"tool_file_manager_factory": self._http_request_tool_file_manager_factory,
|
||||
"file_manager": self._http_request_file_manager,
|
||||
},
|
||||
BuiltinNodeTypes.HUMAN_INPUT: lambda: {
|
||||
NodeType.HUMAN_INPUT: lambda: {
|
||||
"form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
|
||||
},
|
||||
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
NodeType.KNOWLEDGE_INDEX: lambda: {
|
||||
"index_processor": IndexProcessor(),
|
||||
"summary_index_service": SummaryIndex(),
|
||||
},
|
||||
NodeType.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
include_http_client=True,
|
||||
),
|
||||
BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: {
|
||||
NodeType.DATASOURCE: lambda: {
|
||||
"datasource_manager": DatasourceManager,
|
||||
},
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: lambda: {
|
||||
"rag_retrieval": self._rag_retrieval,
|
||||
},
|
||||
NodeType.DOCUMENT_EXTRACTOR: lambda: {
|
||||
"unstructured_api_config": self._document_extractor_unstructured_api_config,
|
||||
"http_client": self._http_request_http_client,
|
||||
},
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
NodeType.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
include_http_client=True,
|
||||
),
|
||||
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
NodeType.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
include_http_client=False,
|
||||
),
|
||||
BuiltinNodeTypes.TOOL: lambda: {
|
||||
NodeType.TOOL: lambda: {
|
||||
"tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
|
||||
},
|
||||
BuiltinNodeTypes.AGENT: lambda: {
|
||||
"strategy_resolver": self._agent_strategy_resolver,
|
||||
"presentation_provider": self._agent_strategy_presentation_provider,
|
||||
"runtime_support": self._agent_runtime_support,
|
||||
"message_transformer": self._agent_message_transformer,
|
||||
},
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
return node_class(
|
||||
@@ -356,11 +238,16 @@ class DifyNodeFactory(NodeFactory):
|
||||
|
||||
@staticmethod
|
||||
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
# Import lazily to avoid a module cycle with `node_resolution`, which
|
||||
# imports `get_node_type_classes_mapping` from this module at import time.
|
||||
from core.workflow.node_resolution import resolve_workflow_node_class
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
return resolve_workflow_node_class(node_type=node_type, node_version=node_version)
|
||||
latest_node_class = node_mapping.get(LATEST_VERSION)
|
||||
matched_node_class = node_mapping.get(node_version)
|
||||
node_class = matched_node_class or latest_node_class
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
return node_class
|
||||
|
||||
def _build_llm_compatible_node_init_kwargs(
|
||||
self,
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
from core.workflow.node_factory import get_node_type_classes_mapping
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.node import Node
|
||||
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
|
||||
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
node_mapping = get_node_type_classes_mapping().get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
latest_node_class = node_mapping.get(LATEST_VERSION)
|
||||
matched_node_class = node_mapping.get(node_version)
|
||||
node_class = matched_node_class or latest_node_class
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
return node_class
|
||||
@@ -1 +0,0 @@
|
||||
"""Workflow node implementations that remain under the legacy core.workflow namespace."""
|
||||
@@ -1,4 +0,0 @@
|
||||
from .agent_node import AgentNode
|
||||
from .entities import AgentNodeData
|
||||
|
||||
__all__ = ["AgentNode", "AgentNodeData"]
|
||||
@@ -1,188 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
from .entities import AgentNodeData
|
||||
from .exceptions import (
|
||||
AgentInvocationError,
|
||||
AgentMessageTransformError,
|
||||
)
|
||||
from .message_transformer import AgentMessageTransformer
|
||||
from .runtime_support import AgentRuntimeSupport
|
||||
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class AgentNode(Node[AgentNodeData]):
|
||||
node_type = BuiltinNodeTypes.AGENT
|
||||
|
||||
_strategy_resolver: AgentStrategyResolver
|
||||
_presentation_provider: AgentStrategyPresentationProvider
|
||||
_runtime_support: AgentRuntimeSupport
|
||||
_message_transformer: AgentMessageTransformer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
strategy_resolver: AgentStrategyResolver,
|
||||
presentation_provider: AgentStrategyPresentationProvider,
|
||||
runtime_support: AgentRuntimeSupport,
|
||||
message_transformer: AgentMessageTransformer,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._strategy_resolver = strategy_resolver
|
||||
self._presentation_provider = presentation_provider
|
||||
self._runtime_support = runtime_support
|
||||
self._message_transformer = message_transformer
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def populate_start_event(self, event) -> None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
event.extras["agent_strategy"] = {
|
||||
"name": self.node_data.agent_strategy_name,
|
||||
"icon": self._presentation_provider.get_icon(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
),
|
||||
}
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
|
||||
try:
|
||||
strategy = self._strategy_resolver.resolve(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self.node_data.agent_strategy_name,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=f"Failed to get agent strategy: {str(e)}",
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
agent_parameters = strategy.get_parameters()
|
||||
|
||||
parameters = self._runtime_support.build_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
strategy=strategy,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
invoke_from=dify_ctx.invoke_from,
|
||||
)
|
||||
parameters_for_log = self._runtime_support.build_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
strategy=strategy,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
invoke_from=dify_ctx.invoke_from,
|
||||
for_log=True,
|
||||
)
|
||||
credentials = self._runtime_support.build_credentials(parameters=parameters)
|
||||
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = strategy.invoke(
|
||||
params=parameters,
|
||||
user_id=dify_ctx.user_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
except Exception as e:
|
||||
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(error),
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
yield from self._message_transformer.transform(
|
||||
messages=message_stream,
|
||||
tool_info={
|
||||
"icon": self._presentation_provider.get_icon(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
),
|
||||
"agent_strategy": self.node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
f"Failed to transform agent message: {str(e)}", original_error=e
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(transform_error),
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: AgentNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
result: dict[str, Any] = {}
|
||||
typed_node_data = node_data
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
match input.type:
|
||||
case "mixed" | "constant":
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
@@ -1,292 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File, FileTransferMethod
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.node_events import (
|
||||
AgentLogEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from dify_graph.variables.segments import ArrayFileSegment
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError
|
||||
|
||||
|
||||
class AgentMessageTransformer:
|
||||
def transform(
|
||||
self,
|
||||
*,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json_list: list[dict | list] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
ToolInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == BuiltinNodeTypes.AGENT:
|
||||
if isinstance(message.message.json_object, dict):
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
else:
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
agent_execution_metadata = {}
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise AgentVariableTypeError(
|
||||
"When 'stream' is True, 'variable_value' must be a string.",
|
||||
variable_name=variable_name,
|
||||
expected_type="str",
|
||||
actual_type=type(variable_value).__name__,
|
||||
)
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, dict)
|
||||
if "file" not in message.meta:
|
||||
raise AgentNodeError("File message is missing 'file' key in meta")
|
||||
|
||||
if not isinstance(message.meta["file"], File):
|
||||
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = tool_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
icon_dark = None
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
icon_dark = builtin_tool.icon_dark
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
message_id=message.message.id,
|
||||
node_execution_id=node_execution_id,
|
||||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
status=message.message.status.value,
|
||||
data=message.message.data,
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
for log in agent_logs:
|
||||
if log.message_id == agent_log.message_id:
|
||||
log.data = agent_log.data
|
||||
log.status = agent_log.status
|
||||
log.error = agent_log.error
|
||||
log.label = agent_log.label
|
||||
log.metadata = agent_log.metadata
|
||||
break
|
||||
else:
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
if agent_logs:
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.message_id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
)
|
||||
if json_list:
|
||||
json_output.extend(json_list)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"text": text,
|
||||
"usage": jsonable_encoder(llm_usage),
|
||||
"files": ArrayFileSegment(value=files),
|
||||
"json": json_output,
|
||||
**variables,
|
||||
},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
@@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
|
||||
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver, ResolvedAgentStrategy
|
||||
|
||||
|
||||
class PluginAgentStrategyResolver(AgentStrategyResolver):
|
||||
def resolve(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
agent_strategy_provider_name: str,
|
||||
agent_strategy_name: str,
|
||||
) -> ResolvedAgentStrategy:
|
||||
return get_plugin_agent_strategy(
|
||||
tenant_id=tenant_id,
|
||||
agent_strategy_provider_name=agent_strategy_provider_name,
|
||||
agent_strategy_name=agent_strategy_name,
|
||||
)
|
||||
|
||||
|
||||
class PluginAgentStrategyPresentationProvider(AgentStrategyPresentationProvider):
|
||||
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None:
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
manager = PluginInstaller()
|
||||
try:
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == agent_strategy_provider_name
|
||||
)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
return current_plugin.declaration.icon
|
||||
@@ -1,276 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from dify_graph.enums import SystemVariableKey
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
|
||||
from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
|
||||
from .strategy_protocols import ResolvedAgentStrategy
|
||||
|
||||
|
||||
class AgentRuntimeSupport:
|
||||
def build_parameters(
|
||||
self,
|
||||
*,
|
||||
agent_parameters: Sequence[AgentStrategyParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
invoke_from: Any,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
case _:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
if value_param and value_param.get("type", "") == "variable":
|
||||
variable_selector = value_param.get("value")
|
||||
if not variable_selector:
|
||||
raise ValueError("Variable selector is missing for a variable-type parameter.")
|
||||
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
|
||||
params[key] = variable.value
|
||||
else:
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_type=provider_type,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
runtime_variable_pool = variable_pool
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id,
|
||||
app_id,
|
||||
entity,
|
||||
invoke_from,
|
||||
runtime_variable_pool,
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
ToolParameter.ToolParameterForm.FORM
|
||||
if tool_runtime_params.name in manual_input_params
|
||||
else tool_runtime_params.form
|
||||
)
|
||||
manual_input_value = {}
|
||||
if tool_runtime.entity.parameters:
|
||||
manual_input_value = {
|
||||
key: value for key, value in parameters.items() if key in manual_input_params
|
||||
}
|
||||
runtime_parameters = {
|
||||
**tool_runtime.runtime.runtime_parameters,
|
||||
**manual_input_value,
|
||||
}
|
||||
tool_value.append(
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||
value = cast(dict[str, Any], value)
|
||||
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
memory = self.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=app_id,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
if memory:
|
||||
prompt_messages = memory.get_history_prompt_messages(
|
||||
message_limit=node_data.memory.window.size or None
|
||||
)
|
||||
history_prompt_messages = [
|
||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||
]
|
||||
value["history_prompt_messages"] = history_prompt_messages
|
||||
if model_schema:
|
||||
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
|
||||
value["entity"] = model_schema.model_dump(mode="json")
|
||||
else:
|
||||
value["entity"] = None
|
||||
result[parameter_name] = value
|
||||
|
||||
return result
|
||||
|
||||
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
|
||||
credentials = InvokeCredentials()
|
||||
credentials.tool_credentials = {}
|
||||
for tool in parameters.get("tools", []):
|
||||
if not tool.get("credential_id"):
|
||||
continue
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
except ValidationError:
|
||||
continue
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
return credentials
|
||||
|
||||
def fetch_memory(
|
||||
self,
|
||||
*,
|
||||
variable_pool: VariablePool,
|
||||
app_id: str,
|
||||
model_instance: ModelInstance,
|
||||
) -> TokenBufferMemory | None:
|
||||
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=tenant_id,
|
||||
provider=value.get("provider", ""),
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model=model_name,
|
||||
)
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
return model_instance, model_schema
|
||||
|
||||
@staticmethod
|
||||
def _remove_unsupported_model_features_for_old_version(model_schema: AIModelEntity) -> AIModelEntity:
|
||||
if model_schema.features:
|
||||
for feature in model_schema.features[:]:
|
||||
try:
|
||||
AgentOldVersionModelFeatures(feature.value)
|
||||
except ValueError:
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
@staticmethod
|
||||
def _filter_mcp_type_tool(
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tools: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
@@ -1,39 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class ResolvedAgentStrategy(Protocol):
|
||||
meta_version: str | None
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]: ...
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
*,
|
||||
params: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
credentials: InvokeCredentials | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]: ...
|
||||
|
||||
|
||||
class AgentStrategyResolver(Protocol):
|
||||
def resolve(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
agent_strategy_provider_name: str,
|
||||
agent_strategy_name: str,
|
||||
) -> ResolvedAgentStrategy: ...
|
||||
|
||||
|
||||
class AgentStrategyPresentationProvider(Protocol):
|
||||
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None: ...
|
||||
@@ -1 +0,0 @@
|
||||
"""Datasource workflow node package."""
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Knowledge index workflow node package."""
|
||||
|
||||
KNOWLEDGE_INDEX_NODE_TYPE = "knowledge-index"
|
||||
|
||||
__all__ = ["KNOWLEDGE_INDEX_NODE_TYPE"]
|
||||
@@ -1 +0,0 @@
|
||||
"""Knowledge retrieval workflow node package."""
|
||||
@@ -1,3 +0,0 @@
|
||||
from .trigger_schedule_node import TriggerScheduleNode
|
||||
|
||||
__all__ = ["TriggerScheduleNode"]
|
||||
@@ -9,7 +9,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.node_resolution import resolve_workflow_node_class
|
||||
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
@@ -22,8 +21,9 @@ from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLay
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_engine.protocols.command_channel import CommandChannel
|
||||
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from dify_graph.nodes import BuiltinNodeTypes
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
@@ -253,7 +253,7 @@ class WorkflowEntry:
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
)
|
||||
if node_type != BuiltinNodeTypes.DATASOURCE:
|
||||
if node_type != NodeType.DATASOURCE:
|
||||
cls.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
@@ -303,7 +303,7 @@ class WorkflowEntry:
|
||||
"height": node_height,
|
||||
"type": "custom",
|
||||
"data": {
|
||||
"type": BuiltinNodeTypes.START,
|
||||
"type": NodeType.START,
|
||||
"title": "Start",
|
||||
"desc": "Start",
|
||||
},
|
||||
@@ -339,11 +339,11 @@ class WorkflowEntry:
|
||||
# Create a minimal graph for single node execution
|
||||
graph_dict = cls._create_single_node_graph(node_id, node_data)
|
||||
|
||||
node_type = node_data.get("type", "")
|
||||
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
|
||||
node_type = NodeType(node_data.get("type", ""))
|
||||
if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
|
||||
raise ValueError(f"Node type {node_type} not supported")
|
||||
|
||||
node_cls = resolve_workflow_node_class(node_type=node_type, node_version="1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"]
|
||||
if not node_cls:
|
||||
raise ValueError(f"Node class not found for node type {node_type}")
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ The codebase enforces strict layering via import-linter:
|
||||
1. Create node class in `nodes/<node_type>/`
|
||||
1. Inherit from `BaseNode` or appropriate base class
|
||||
1. Implement `_run()` method
|
||||
1. Ensure the node module is importable under `nodes/<node_type>/`
|
||||
1. Register in `nodes/node_mapping.py`
|
||||
1. Add tests in `tests/unit_tests/dify_graph/nodes/`
|
||||
|
||||
### Implementing a Custom Layer
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
from .workflow_start_reason import WorkflowStartReason
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
|
||||
8
api/dify_graph/entities/agent.py
Normal file
8
api/dify_graph/entities/agent.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentNodeStrategyInit(BaseModel):
|
||||
"""Agent node strategy initialization data."""
|
||||
|
||||
name: str
|
||||
icon: str | None = None
|
||||
@@ -121,8 +121,6 @@ class DefaultValue(BaseModel):
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
# Raw graph payloads are first validated through `NodeConfigDictAdapter`, where
|
||||
# `node["data"]` is typed as `BaseNodeData` before the concrete node class is known.
|
||||
# `type` therefore accepts downstream string node kinds; unknown node implementations
|
||||
# are rejected later when the node factory resolves the node registry.
|
||||
# At that boundary, node-specific fields are still "extra" relative to this shared DTO,
|
||||
# and persisted templates/workflows also carry undeclared compatibility keys such as
|
||||
# `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive
|
||||
|
||||
@@ -48,7 +48,7 @@ class WorkflowNodeExecution(BaseModel):
|
||||
index: int # Sequence number for ordering in trace visualization
|
||||
predecessor_node_id: str | None = None # ID of the node that executed before this one
|
||||
node_id: str # ID of the node being executed
|
||||
node_type: NodeType # Type of node (e.g., start, llm, downstream response node)
|
||||
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
|
||||
title: str # Display title of the node
|
||||
|
||||
# Execution data
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from enum import StrEnum
|
||||
from typing import ClassVar, TypeAlias
|
||||
|
||||
|
||||
class NodeState(StrEnum):
|
||||
@@ -34,71 +33,56 @@ class SystemVariableKey(StrEnum):
|
||||
INVOKE_FROM = "invoke_from"
|
||||
|
||||
|
||||
NodeType: TypeAlias = str
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
END = "end"
|
||||
ANSWER = "answer"
|
||||
LLM = "llm"
|
||||
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
|
||||
KNOWLEDGE_INDEX = "knowledge-index"
|
||||
IF_ELSE = "if-else"
|
||||
CODE = "code"
|
||||
TEMPLATE_TRANSFORM = "template-transform"
|
||||
QUESTION_CLASSIFIER = "question-classifier"
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
DATASOURCE = "datasource"
|
||||
VARIABLE_AGGREGATOR = "variable-aggregator"
|
||||
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
LOOP = "loop"
|
||||
LOOP_START = "loop-start"
|
||||
LOOP_END = "loop-end"
|
||||
ITERATION = "iteration"
|
||||
ITERATION_START = "iteration-start" # Fake start node for iteration.
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
VARIABLE_ASSIGNER = "assigner"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
AGENT = "agent"
|
||||
TRIGGER_WEBHOOK = "trigger-webhook"
|
||||
TRIGGER_SCHEDULE = "trigger-schedule"
|
||||
TRIGGER_PLUGIN = "trigger-plugin"
|
||||
HUMAN_INPUT = "human-input"
|
||||
|
||||
@property
|
||||
def is_trigger_node(self) -> bool:
|
||||
"""Check if this node type is a trigger node."""
|
||||
return self in [
|
||||
NodeType.TRIGGER_WEBHOOK,
|
||||
NodeType.TRIGGER_SCHEDULE,
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
]
|
||||
|
||||
class BuiltinNodeTypes:
|
||||
"""Built-in node type string constants.
|
||||
|
||||
`node_type` values are plain strings throughout the graph runtime. This namespace
|
||||
only exposes the built-in values shipped by `dify_graph`; downstream packages can
|
||||
use additional strings without extending this class.
|
||||
"""
|
||||
|
||||
START: ClassVar[NodeType] = "start"
|
||||
END: ClassVar[NodeType] = "end"
|
||||
ANSWER: ClassVar[NodeType] = "answer"
|
||||
LLM: ClassVar[NodeType] = "llm"
|
||||
KNOWLEDGE_RETRIEVAL: ClassVar[NodeType] = "knowledge-retrieval"
|
||||
IF_ELSE: ClassVar[NodeType] = "if-else"
|
||||
CODE: ClassVar[NodeType] = "code"
|
||||
TEMPLATE_TRANSFORM: ClassVar[NodeType] = "template-transform"
|
||||
QUESTION_CLASSIFIER: ClassVar[NodeType] = "question-classifier"
|
||||
HTTP_REQUEST: ClassVar[NodeType] = "http-request"
|
||||
TOOL: ClassVar[NodeType] = "tool"
|
||||
DATASOURCE: ClassVar[NodeType] = "datasource"
|
||||
VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-aggregator"
|
||||
LEGACY_VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-assigner"
|
||||
LOOP: ClassVar[NodeType] = "loop"
|
||||
LOOP_START: ClassVar[NodeType] = "loop-start"
|
||||
LOOP_END: ClassVar[NodeType] = "loop-end"
|
||||
ITERATION: ClassVar[NodeType] = "iteration"
|
||||
ITERATION_START: ClassVar[NodeType] = "iteration-start"
|
||||
PARAMETER_EXTRACTOR: ClassVar[NodeType] = "parameter-extractor"
|
||||
VARIABLE_ASSIGNER: ClassVar[NodeType] = "assigner"
|
||||
DOCUMENT_EXTRACTOR: ClassVar[NodeType] = "document-extractor"
|
||||
LIST_OPERATOR: ClassVar[NodeType] = "list-operator"
|
||||
AGENT: ClassVar[NodeType] = "agent"
|
||||
HUMAN_INPUT: ClassVar[NodeType] = "human-input"
|
||||
|
||||
|
||||
BUILT_IN_NODE_TYPES: tuple[NodeType, ...] = (
|
||||
BuiltinNodeTypes.START,
|
||||
BuiltinNodeTypes.END,
|
||||
BuiltinNodeTypes.ANSWER,
|
||||
BuiltinNodeTypes.LLM,
|
||||
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL,
|
||||
BuiltinNodeTypes.IF_ELSE,
|
||||
BuiltinNodeTypes.CODE,
|
||||
BuiltinNodeTypes.TEMPLATE_TRANSFORM,
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER,
|
||||
BuiltinNodeTypes.HTTP_REQUEST,
|
||||
BuiltinNodeTypes.TOOL,
|
||||
BuiltinNodeTypes.DATASOURCE,
|
||||
BuiltinNodeTypes.VARIABLE_AGGREGATOR,
|
||||
BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR,
|
||||
BuiltinNodeTypes.LOOP,
|
||||
BuiltinNodeTypes.LOOP_START,
|
||||
BuiltinNodeTypes.LOOP_END,
|
||||
BuiltinNodeTypes.ITERATION,
|
||||
BuiltinNodeTypes.ITERATION_START,
|
||||
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
|
||||
BuiltinNodeTypes.VARIABLE_ASSIGNER,
|
||||
BuiltinNodeTypes.DOCUMENT_EXTRACTOR,
|
||||
BuiltinNodeTypes.LIST_OPERATOR,
|
||||
BuiltinNodeTypes.AGENT,
|
||||
BuiltinNodeTypes.HUMAN_INPUT,
|
||||
)
|
||||
@property
|
||||
def is_start_node(self) -> bool:
|
||||
"""Check if this node type can serve as a workflow entry point."""
|
||||
return self in [
|
||||
NodeType.START,
|
||||
NodeType.DATASOURCE,
|
||||
NodeType.TRIGGER_WEBHOOK,
|
||||
NodeType.TRIGGER_SCHEDULE,
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
]
|
||||
|
||||
|
||||
class NodeExecutionType(StrEnum):
|
||||
@@ -252,6 +236,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
TRIGGER_INFO = "trigger_info"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
LOOP_ID = "loop_id"
|
||||
|
||||
@@ -83,6 +83,50 @@ class Graph:
|
||||
|
||||
return node_configs_map
|
||||
|
||||
@classmethod
|
||||
def _find_root_node_id(
|
||||
cls,
|
||||
node_configs_map: Mapping[str, NodeConfigDict],
|
||||
edge_configs: Sequence[Mapping[str, object]],
|
||||
root_node_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Find the root node ID if not specified.
|
||||
|
||||
:param node_configs_map: mapping of node ID to node config
|
||||
:param edge_configs: list of edge configurations
|
||||
:param root_node_id: explicitly specified root node ID
|
||||
:return: determined root node ID
|
||||
"""
|
||||
if root_node_id:
|
||||
if root_node_id not in node_configs_map:
|
||||
raise ValueError(f"Root node id {root_node_id} not found in the graph")
|
||||
return root_node_id
|
||||
|
||||
# Find nodes with no incoming edges
|
||||
nodes_with_incoming: set[str] = set()
|
||||
for edge_config in edge_configs:
|
||||
target = edge_config.get("target")
|
||||
if isinstance(target, str):
|
||||
nodes_with_incoming.add(target)
|
||||
|
||||
root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]
|
||||
|
||||
# Prefer START node if available
|
||||
start_node_id = None
|
||||
for nid in root_candidates:
|
||||
node_data = node_configs_map[nid]["data"]
|
||||
if node_data.type.is_start_node:
|
||||
start_node_id = nid
|
||||
break
|
||||
|
||||
root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)
|
||||
|
||||
if not root_node_id:
|
||||
raise ValueError("Unable to determine root node ID")
|
||||
|
||||
return root_node_id
|
||||
|
||||
@classmethod
|
||||
def _build_edges(
|
||||
cls, edge_configs: list[dict[str, object]]
|
||||
@@ -257,15 +301,15 @@ class Graph:
|
||||
*,
|
||||
graph_config: Mapping[str, object],
|
||||
node_factory: NodeFactory,
|
||||
root_node_id: str,
|
||||
root_node_id: str | None = None,
|
||||
skip_validation: bool = False,
|
||||
) -> Graph:
|
||||
"""
|
||||
Initialize a graph with an explicit execution entry point.
|
||||
Initialize graph
|
||||
|
||||
:param graph_config: graph config containing nodes and edges
|
||||
:param node_factory: factory for creating node instances from config data
|
||||
:param root_node_id: active root node id
|
||||
:param root_node_id: root node id
|
||||
:return: graph instance
|
||||
"""
|
||||
# Parse configs
|
||||
@@ -283,8 +327,8 @@ class Graph:
|
||||
# Parse node configurations
|
||||
node_configs_map = cls._parse_node_configs(node_configs)
|
||||
|
||||
if root_node_id not in node_configs_map:
|
||||
raise ValueError(f"Root node id {root_node_id} not found in the graph")
|
||||
# Find root node
|
||||
root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)
|
||||
|
||||
# Build edges
|
||||
edges, in_edges, out_edges = cls._build_edges(edge_configs)
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeType
|
||||
from dify_graph.enums import NodeExecutionType, NodeType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph import Graph
|
||||
@@ -71,7 +71,7 @@ class _RootNodeValidator:
|
||||
"""Validates root node invariants."""
|
||||
|
||||
invalid_root_code: str = "INVALID_ROOT"
|
||||
container_entry_types: tuple[NodeType, ...] = (BuiltinNodeTypes.ITERATION_START, BuiltinNodeTypes.LOOP_START)
|
||||
container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
root_node = graph.root_node
|
||||
@@ -86,7 +86,7 @@ class _RootNodeValidator:
|
||||
)
|
||||
return issues
|
||||
|
||||
node_type = root_node.node_type
|
||||
node_type = getattr(root_node, "node_type", None)
|
||||
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
@@ -114,9 +114,45 @@ class GraphValidator:
|
||||
raise GraphValidationError(issues)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _TriggerStartExclusivityValidator:
|
||||
"""Ensures trigger nodes do not coexist with UserInput (start) nodes."""
|
||||
|
||||
conflict_code: str = "TRIGGER_START_NODE_CONFLICT"
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
start_node_id: str | None = None
|
||||
trigger_node_ids: list[str] = []
|
||||
|
||||
for node in graph.nodes.values():
|
||||
node_type = getattr(node, "node_type", None)
|
||||
if not isinstance(node_type, NodeType):
|
||||
continue
|
||||
|
||||
if node_type == NodeType.START:
|
||||
start_node_id = node.id
|
||||
elif node_type.is_trigger_node:
|
||||
trigger_node_ids.append(node.id)
|
||||
|
||||
if start_node_id and trigger_node_ids:
|
||||
trigger_list = ", ".join(trigger_node_ids)
|
||||
return [
|
||||
GraphValidationIssue(
|
||||
code=self.conflict_code,
|
||||
message=(
|
||||
f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}."
|
||||
),
|
||||
node_id=start_node_id,
|
||||
)
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
|
||||
_EdgeEndpointValidator(),
|
||||
_RootNodeValidator(),
|
||||
_TriggerStartExclusivityValidator(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,5 @@ of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
from .coordinator import ResponseStreamCoordinator
|
||||
from .session import RESPONSE_SESSION_NODE_TYPES
|
||||
|
||||
__all__ = ["RESPONSE_SESSION_NODE_TYPES", "ResponseStreamCoordinator"]
|
||||
__all__ = ["ResponseStreamCoordinator"]
|
||||
|
||||
@@ -3,34 +3,19 @@ Internal response session management for response coordinator.
|
||||
|
||||
This module contains the private ResponseSession class used internally
|
||||
by ResponseStreamCoordinator to manage streaming sessions.
|
||||
|
||||
`RESPONSE_SESSION_NODE_TYPES` is intentionally mutable so downstream applications
|
||||
can opt additional response-capable node types into session creation without
|
||||
patching the coordinator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol, cast
|
||||
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.nodes.answer.answer_node import AnswerNode
|
||||
from dify_graph.nodes.base.template import Template
|
||||
from dify_graph.nodes.end.end_node import EndNode
|
||||
from dify_graph.nodes.knowledge_index import KnowledgeIndexNode
|
||||
from dify_graph.runtime.graph_runtime_state import NodeProtocol
|
||||
|
||||
|
||||
class _ResponseSessionNodeProtocol(NodeProtocol, Protocol):
|
||||
"""Structural contract required from nodes that can open a response session."""
|
||||
|
||||
def get_streaming_template(self) -> Template: ...
|
||||
|
||||
|
||||
RESPONSE_SESSION_NODE_TYPES: list[NodeType] = [
|
||||
BuiltinNodeTypes.ANSWER,
|
||||
BuiltinNodeTypes.END,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseSession:
|
||||
"""
|
||||
@@ -48,9 +33,10 @@ class ResponseSession:
|
||||
"""
|
||||
Create a ResponseSession from a response-capable node.
|
||||
|
||||
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer.
|
||||
At runtime this must be a node whose `node_type` is listed in `RESPONSE_SESSION_NODE_TYPES`
|
||||
and which implements `get_streaming_template()`.
|
||||
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer,
|
||||
but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides:
|
||||
- `id: str`
|
||||
- `get_streaming_template() -> Template`
|
||||
|
||||
Args:
|
||||
node: Node from the materialized workflow graph.
|
||||
@@ -61,22 +47,11 @@ class ResponseSession:
|
||||
Raises:
|
||||
TypeError: If node is not a supported response node type.
|
||||
"""
|
||||
if node.node_type not in RESPONSE_SESSION_NODE_TYPES:
|
||||
supported_node_types = ", ".join(RESPONSE_SESSION_NODE_TYPES)
|
||||
raise TypeError(
|
||||
"ResponseSession.from_node only supports node types in "
|
||||
f"RESPONSE_SESSION_NODE_TYPES: {supported_node_types}"
|
||||
)
|
||||
|
||||
response_node = cast(_ResponseSessionNodeProtocol, node)
|
||||
try:
|
||||
template = response_node.get_streaming_template()
|
||||
except AttributeError as exc:
|
||||
raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes") from exc
|
||||
|
||||
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
|
||||
raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode")
|
||||
return cls(
|
||||
node_id=node.id,
|
||||
template=template,
|
||||
template=node.get_streaming_template(),
|
||||
)
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
|
||||
@@ -4,6 +4,7 @@ from datetime import datetime
|
||||
from pydantic import Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from dify_graph.entities import AgentNodeStrategyInit
|
||||
from dify_graph.entities.pause_reason import PauseReason
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
@@ -12,8 +13,8 @@ from .base import GraphNodeEventBase
|
||||
class NodeRunStartedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
predecessor_node_id: str | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
extras: dict[str, object] = Field(default_factory=dict)
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode
|
||||
provider_type: str = ""
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from dify_graph.entities.pause_reason import PauseReason
|
||||
from dify_graph.file import File
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
@@ -13,7 +13,7 @@ from .base import NodeEventBase
|
||||
|
||||
|
||||
class RunRetrieverResourceEvent(NodeEventBase):
|
||||
retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources")
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
context_files: list[File] | None = Field(default=None, description="context files")
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
__all__ = ["BuiltinNodeTypes"]
|
||||
__all__ = ["NodeType"]
|
||||
|
||||
3
api/dify_graph/nodes/agent/__init__.py
Normal file
3
api/dify_graph/nodes/agent/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .agent_node import AgentNode
|
||||
|
||||
__all__ = ["AgentNode"]
|
||||
761
api/dify_graph/nodes/agent/agent_node.py
Normal file
761
api/dify_graph/nodes/agent/agent_node.py
Normal file
@@ -0,0 +1,761 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from dify_graph.file import File, FileTransferMethod
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.node_events import (
|
||||
AgentLogEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from dify_graph.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables.segments import ArrayFileSegment, StringSegment
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
from models import ToolFile
|
||||
from models.model import Conversation
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .exc import (
|
||||
AgentInputTypeError,
|
||||
AgentInvocationError,
|
||||
AgentMessageTransformError,
|
||||
AgentNodeError,
|
||||
AgentVariableNotFoundError,
|
||||
AgentVariableTypeError,
|
||||
ToolFileNotFoundError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
|
||||
class AgentNode(Node[AgentNodeData]):
|
||||
"""
|
||||
Agent Node
|
||||
"""
|
||||
|
||||
node_type = NodeType.AGENT
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
|
||||
try:
|
||||
strategy = get_plugin_agent_strategy(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self.node_data.agent_strategy_name,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=f"Failed to get agent strategy: {str(e)}",
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
agent_parameters = strategy.get_parameters()
|
||||
|
||||
# get parameters
|
||||
parameters = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
strategy=strategy,
|
||||
)
|
||||
parameters_for_log = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
strategy=strategy,
|
||||
)
|
||||
credentials = self._generate_credentials(parameters=parameters)
|
||||
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = strategy.invoke(
|
||||
params=parameters,
|
||||
user_id=dify_ctx.user_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
except Exception as e:
|
||||
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(error),
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
yield from self._transform_message(
|
||||
messages=message_stream,
|
||||
tool_info={
|
||||
"icon": self.agent_strategy_icon,
|
||||
"agent_strategy": self.node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
f"Failed to transform agent message: {str(e)}", original_error=e
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(transform_error),
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_agent_parameters(
|
||||
self,
|
||||
*,
|
||||
agent_parameters: Sequence[AgentStrategyParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
for_log: bool = False,
|
||||
strategy: PluginAgentStrategy,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
agent_parameters (Sequence[AgentParameter]): The list of agent parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (AgentNodeData): The data associated with the agent node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
case _:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
if value_param and value_param.get("type", "") == "variable":
|
||||
variable_selector = value_param.get("value")
|
||||
if not variable_selector:
|
||||
raise ValueError("Variable selector is missing for a variable-type parameter.")
|
||||
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
|
||||
params[key] = variable.value
|
||||
else:
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_type=provider_type,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use the node_data.version field for judgment
|
||||
# But for backward compatibility with historical data
|
||||
# this version field judgment is still preserved here.
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
runtime_variable_pool = variable_pool
|
||||
dify_ctx = self.require_dify_context()
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
dify_ctx.tenant_id,
|
||||
dify_ctx.app_id,
|
||||
entity,
|
||||
dify_ctx.invoke_from,
|
||||
runtime_variable_pool,
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
ToolParameter.ToolParameterForm.FORM
|
||||
if tool_runtime_params.name in manual_input_params
|
||||
else tool_runtime_params.form
|
||||
)
|
||||
manual_input_value = {}
|
||||
if tool_runtime.entity.parameters:
|
||||
manual_input_value = {
|
||||
key: value for key, value in parameters.items() if key in manual_input_params
|
||||
}
|
||||
runtime_parameters = {
|
||||
**tool_runtime.runtime.runtime_parameters,
|
||||
**manual_input_value,
|
||||
}
|
||||
tool_value.append(
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||
value = cast(dict[str, Any], value)
|
||||
model_instance, model_schema = self._fetch_model(value)
|
||||
# memory config
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
memory = self._fetch_memory(model_instance)
|
||||
if memory:
|
||||
prompt_messages = memory.get_history_prompt_messages(
|
||||
message_limit=node_data.memory.window.size or None
|
||||
)
|
||||
history_prompt_messages = [
|
||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||
]
|
||||
value["history_prompt_messages"] = history_prompt_messages
|
||||
if model_schema:
|
||||
# remove structured output feature to support old version agent plugin
|
||||
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
|
||||
value["entity"] = model_schema.model_dump(mode="json")
|
||||
else:
|
||||
value["entity"] = None
|
||||
result[parameter_name] = value
|
||||
|
||||
return result
|
||||
|
||||
def _generate_credentials(
|
||||
self,
|
||||
parameters: dict[str, Any],
|
||||
) -> InvokeCredentials:
|
||||
"""
|
||||
Generate credentials based on the given agent parameters.
|
||||
"""
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
credentials = InvokeCredentials()
|
||||
|
||||
# generate credentials for tools selector
|
||||
credentials.tool_credentials = {}
|
||||
for tool in parameters.get("tools", []):
|
||||
if tool.get("credential_id"):
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
except ValidationError:
|
||||
continue
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: AgentNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
result: dict[str, Any] = {}
|
||||
typed_node_data = node_data
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
match input.type:
|
||||
case "mixed" | "constant":
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def agent_strategy_icon(self) -> str | None:
|
||||
"""
|
||||
Get agent strategy icon
|
||||
:return:
|
||||
"""
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
manager = PluginInstaller()
|
||||
dify_ctx = self.require_dify_context()
|
||||
plugins = manager.list_plugins(dify_ctx.tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
icon = None
|
||||
return icon
|
||||
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
|
||||
# get conversation id
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
|
||||
)
|
||||
conversation = session.scalar(stmt)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
return memory
|
||||
|
||||
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
dify_ctx = self.require_dify_context()
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM, model=model_name
|
||||
)
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
return model_instance, model_schema
|
||||
|
||||
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
|
||||
if model_schema.features:
|
||||
for feature in model_schema.features[:]: # Create a copy to safely modify during iteration
|
||||
try:
|
||||
AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value
|
||||
except ValueError:
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter MCP type tool
|
||||
:param strategy: plugin agent strategy
|
||||
:param tool: tool
|
||||
:return: filtered tool dict
|
||||
"""
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
else:
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json_list: list[dict | list] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
ToolInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == NodeType.AGENT:
|
||||
if isinstance(message.message.json_object, dict):
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
else:
|
||||
msg_metadata = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
agent_execution_metadata = {}
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise AgentVariableTypeError(
|
||||
"When 'stream' is True, 'variable_value' must be a string.",
|
||||
variable_name=variable_name,
|
||||
expected_type="str",
|
||||
actual_type=type(variable_value).__name__,
|
||||
)
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, dict)
|
||||
# Validate that meta contains a 'file' key
|
||||
if "file" not in message.meta:
|
||||
raise AgentNodeError("File message is missing 'file' key in meta")
|
||||
|
||||
# Validate that the file is an instance of File
|
||||
if not isinstance(message.meta["file"], File):
|
||||
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = tool_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
icon_dark = None
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
icon_dark = builtin_tool.icon_dark
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
message_id=message.message.id,
|
||||
node_execution_id=node_execution_id,
|
||||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
status=message.message.status.value,
|
||||
data=message.message.data,
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
# check if the agent log is already in the list
|
||||
for log in agent_logs:
|
||||
if log.message_id == agent_log.message_id:
|
||||
# update the log
|
||||
log.data = agent_log.data
|
||||
log.status = agent_log.status
|
||||
log.error = agent_log.error
|
||||
log.label = agent_log.label
|
||||
log.metadata = agent_log.metadata
|
||||
break
|
||||
else:
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
|
||||
# Step 1: append each agent log as its own dict.
|
||||
if agent_logs:
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.message_id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
)
|
||||
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
|
||||
if json_list:
|
||||
json_output.extend(json_list)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
# Send final chunk events for all streamed outputs
|
||||
# Final chunk for text stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Final chunks for any streamed variables
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"text": text,
|
||||
"usage": jsonable_encoder(llm_usage),
|
||||
"files": ArrayFileSegment(value=files),
|
||||
"json": json_output,
|
||||
**variables,
|
||||
},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
@@ -6,14 +6,14 @@ from pydantic import BaseModel
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class AgentNodeData(BaseNodeData):
|
||||
type: NodeType = BuiltinNodeTypes.AGENT
|
||||
agent_strategy_provider_name: str
|
||||
type: NodeType = NodeType.AGENT
|
||||
agent_strategy_provider_name: str # redundancy
|
||||
agent_strategy_name: str
|
||||
agent_strategy_label: str
|
||||
agent_strategy_label: str # redundancy
|
||||
memory: MemoryConfig | None = None
|
||||
# The version of the tool parameter.
|
||||
# If this value is None, it indicates this is a previous version
|
||||
@@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentMaxIterationError(AgentNodeError):
|
||||
"""Exception raised when the agent exceeds the maximum iteration limit."""
|
||||
|
||||
def __init__(self, max_iteration: int):
|
||||
self.max_iteration = max_iteration
|
||||
super().__init__(
|
||||
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
|
||||
f"The agent was unable to complete the task within the allowed number of iterations."
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.answer.entities import AnswerNodeData
|
||||
from dify_graph.nodes.base.node import Node
|
||||
@@ -11,7 +11,7 @@ from dify_graph.variables import ArrayFileSegment, FileSegment, Segment
|
||||
|
||||
|
||||
class AnswerNode(Node[AnswerNodeData]):
|
||||
node_type = BuiltinNodeTypes.ANSWER
|
||||
node_type = NodeType.ANSWER
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import StrEnum, auto
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class AnswerNodeData(BaseNodeData):
|
||||
@@ -12,7 +12,7 @@ class AnswerNodeData(BaseNodeData):
|
||||
Answer Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = BuiltinNodeTypes.ANSWER
|
||||
type: NodeType = NodeType.ANSWER
|
||||
answer: str = Field(..., description="answer template string")
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import operator
|
||||
import pkgutil
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
@@ -9,7 +11,7 @@ from types import MappingProxyType
|
||||
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
|
||||
from uuid import uuid4
|
||||
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
|
||||
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
@@ -159,7 +161,7 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
Example:
|
||||
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
|
||||
node_type = BuiltinNodeTypes.CODE
|
||||
node_type = NodeType.CODE
|
||||
# No need to implement _get_title, _get_error_strategy, etc.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
@@ -177,8 +179,7 @@ class Node(Generic[NodeDataT]):
|
||||
# Skip base class itself
|
||||
if cls is Node:
|
||||
return
|
||||
# Only register production node implementations defined under the
|
||||
# canonical workflow namespaces.
|
||||
# Only register production node implementations defined under dify_graph.nodes.*
|
||||
# This prevents test helper subclasses from polluting the global registry and
|
||||
# accidentally overriding real node types (e.g., a test Answer node).
|
||||
module_name = getattr(cls, "__module__", "")
|
||||
@@ -186,7 +187,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_type = cls.node_type
|
||||
version = cls.version()
|
||||
bucket = Node._registry.setdefault(node_type, {})
|
||||
if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")):
|
||||
if module_name.startswith("dify_graph.nodes."):
|
||||
# Production node definitions take precedence and may override
|
||||
bucket[version] = cls # type: ignore[index]
|
||||
else:
|
||||
@@ -202,7 +203,6 @@ class Node(Generic[NodeDataT]):
|
||||
else:
|
||||
latest_key = max(version_keys) if version_keys else version
|
||||
bucket["latest"] = bucket[latest_key]
|
||||
Node._registry_version += 1
|
||||
|
||||
@classmethod
|
||||
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
|
||||
@@ -237,11 +237,6 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
# Global registry populated via __init_subclass__
|
||||
_registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
|
||||
_registry_version: ClassVar[int] = 0
|
||||
|
||||
@classmethod
|
||||
def get_registry_version(cls) -> int:
|
||||
return cls._registry_version
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -354,10 +349,6 @@ class Node(Generic[NodeDataT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def populate_start_event(self, event: NodeRunStartedEvent) -> None:
|
||||
"""Allow subclasses to enrich the started event without cross-node imports in the base class."""
|
||||
_ = event
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
@@ -371,10 +362,39 @@ class Node(Generic[NodeDataT]):
|
||||
in_iteration_id=None,
|
||||
start_at=self._start_at,
|
||||
)
|
||||
try:
|
||||
self.populate_start_event(start_event)
|
||||
except Exception:
|
||||
logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True)
|
||||
|
||||
# === FIXME(-LAN-): Needs to refactor.
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
|
||||
if isinstance(self, ToolNode):
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from dify_graph.nodes.datasource.datasource_node import DatasourceNode
|
||||
|
||||
if isinstance(self, DatasourceNode):
|
||||
plugin_id = getattr(self.node_data, "plugin_id", "")
|
||||
provider_name = getattr(self.node_data, "provider_name", "")
|
||||
|
||||
start_event.provider_id = f"{plugin_id}/{provider_name}"
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from dify_graph.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
|
||||
|
||||
if isinstance(self, TriggerEventNode):
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from dify_graph.nodes.agent.agent_node import AgentNode
|
||||
from dify_graph.nodes.agent.entities import AgentNodeData
|
||||
|
||||
if isinstance(self, AgentNode):
|
||||
start_event.agent_strategy = AgentNodeStrategyInit(
|
||||
name=cast(AgentNodeData, self.node_data).agent_strategy_name,
|
||||
icon=self.agent_strategy_icon,
|
||||
)
|
||||
|
||||
# ===
|
||||
yield start_event
|
||||
|
||||
try:
|
||||
@@ -493,10 +513,31 @@ class Node(Generic[NodeDataT]):
|
||||
@abstractmethod
|
||||
def version(cls) -> str:
|
||||
"""`node_version` returns the version of current node type."""
|
||||
# NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so
|
||||
# workflow-layer node resolution can resolve numeric versions and `latest`.
|
||||
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
|
||||
#
|
||||
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
|
||||
# in `api/dify_graph/nodes/__init__.py`.
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@classmethod
|
||||
def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
|
||||
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
|
||||
|
||||
Import all modules under dify_graph.nodes so subclasses register themselves on import.
|
||||
Then we return a readonly view of the registry to avoid accidental mutation.
|
||||
"""
|
||||
# Import all node modules to ensure they are loaded (thus registered)
|
||||
import dify_graph.nodes as _nodes_pkg
|
||||
|
||||
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
|
||||
# Avoid importing modules that depend on the registry to prevent circular imports.
|
||||
if _modname == "dify_graph.nodes.node_mapping":
|
||||
continue
|
||||
importlib.import_module(_modname)
|
||||
|
||||
# Return a readonly view so callers can't mutate the registry by accident
|
||||
return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return False
|
||||
@@ -770,16 +811,11 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
|
||||
retriever_resources = [
|
||||
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
|
||||
]
|
||||
return NodeRunRetrieverResourceEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
retriever_resources=retriever_resources,
|
||||
retriever_resources=event.retriever_resources,
|
||||
context=event.context,
|
||||
node_version=self.version(),
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from textwrap import dedent
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData
|
||||
@@ -72,7 +72,7 @@ _DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = {
|
||||
|
||||
|
||||
class CodeNode(Node[CodeNodeData]):
|
||||
node_type = BuiltinNodeTypes.CODE
|
||||
node_type = NodeType.CODE
|
||||
_limits: CodeNodeLimits
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Annotated, Literal
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
@@ -40,7 +40,7 @@ class CodeNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = BuiltinNodeTypes.CODE
|
||||
type: NodeType = NodeType.CODE
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
|
||||
3
api/dify_graph/nodes/datasource/__init__.py
Normal file
3
api/dify_graph/nodes/datasource/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .datasource_node import DatasourceNode
|
||||
|
||||
__all__ = ["DatasourceNode"]
|
||||
@@ -1,17 +1,22 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, SystemVariableKey, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from dify_graph.node_events import NodeRunResult, StreamCompletedEvent
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.repositories.datasource_manager_protocol import (
|
||||
DatasourceManagerProtocol,
|
||||
DatasourceParameter,
|
||||
OnlineDriveDownloadFileParam,
|
||||
)
|
||||
|
||||
from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam
|
||||
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from .entities import DatasourceNodeData
|
||||
from .exc import DatasourceNodeError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -24,7 +29,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
Datasource Node
|
||||
"""
|
||||
|
||||
node_type = BuiltinNodeTypes.DATASOURCE
|
||||
node_type = NodeType.DATASOURCE
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
def __init__(
|
||||
@@ -33,6 +38,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
datasource_manager: DatasourceManagerProtocol,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -40,11 +46,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self.datasource_manager = DatasourceManager
|
||||
|
||||
def populate_start_event(self, event) -> None:
|
||||
event.provider_id = f"{self.node_data.plugin_id}/{self.node_data.provider_name}"
|
||||
event.provider_type = self.node_data.provider_type
|
||||
self.datasource_manager = datasource_manager
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
@@ -17,7 +17,7 @@ class DatasourceEntity(BaseModel):
|
||||
|
||||
|
||||
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
type: NodeType = BuiltinNodeTypes.DATASOURCE
|
||||
type: NodeType = NodeType.DATASOURCE
|
||||
|
||||
class DatasourceInput(BaseModel):
|
||||
# TODO: check this type
|
||||
@@ -42,14 +42,3 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
return typ
|
||||
|
||||
datasource_parameters: dict[str, DatasourceInput] | None = None
|
||||
|
||||
|
||||
class DatasourceParameter(BaseModel):
|
||||
workspace_id: str
|
||||
page_id: str
|
||||
type: str
|
||||
|
||||
|
||||
class OnlineDriveDownloadFileParam(BaseModel):
|
||||
id: str
|
||||
bucket: str
|
||||
@@ -2,11 +2,11 @@ from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class DocumentExtractorNodeData(BaseNodeData):
|
||||
type: NodeType = BuiltinNodeTypes.DOCUMENT_EXTRACTOR
|
||||
type: NodeType = NodeType.DOCUMENT_EXTRACTOR
|
||||
variable_selector: Sequence[str]
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File, FileTransferMethod, file_manager
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
@@ -46,7 +46,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
Supports plain text, PDF, and DOC/DOCX files.
|
||||
"""
|
||||
|
||||
node_type = BuiltinNodeTypes.DOCUMENT_EXTRACTOR
|
||||
node_type = NodeType.DOCUMENT_EXTRACTOR
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.template import Template
|
||||
@@ -6,7 +6,7 @@ from dify_graph.nodes.end.entities import EndNodeData
|
||||
|
||||
|
||||
class EndNode(Node[EndNodeData]):
|
||||
node_type = BuiltinNodeTypes.END
|
||||
node_type = NodeType.END
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import OutputVariableEntity
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ class EndNodeData(BaseNodeData):
|
||||
END Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = BuiltinNodeTypes.END
|
||||
type: NodeType = NodeType.END
|
||||
outputs: list[OutputVariableEntity]
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import httpx
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config"
|
||||
|
||||
@@ -90,7 +90,7 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = BuiltinNodeTypes.HTTP_REQUEST
|
||||
type: NodeType = NodeType.HTTP_REQUEST
|
||||
method: Literal[
|
||||
"get",
|
||||
"post",
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File, FileTransferMethod
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base import variable_template_parser
|
||||
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
node_type = BuiltinNodeTypes.HTTP_REQUEST
|
||||
node_type = NodeType.HTTP_REQUEST
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -11,7 +11,7 @@ from typing import Annotated, Any, ClassVar, Literal, Self
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables.consts import SELECTORS_LENGTH
|
||||
@@ -215,7 +215,7 @@ class UserAction(BaseModel):
|
||||
class HumanInputNodeData(BaseNodeData):
|
||||
"""Human Input node data."""
|
||||
|
||||
type: NodeType = BuiltinNodeTypes.HUMAN_INPUT
|
||||
type: NodeType = NodeType.HUMAN_INPUT
|
||||
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
|
||||
form_content: str = ""
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import (
|
||||
HumanInputFormFilledEvent,
|
||||
HumanInputFormTimeoutEvent,
|
||||
@@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputNode(Node[HumanInputNodeData]):
|
||||
node_type = BuiltinNodeTypes.HUMAN_INPUT
|
||||
node_type = NodeType.HUMAN_INPUT
|
||||
execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
_BRANCH_SELECTION_KEYS: tuple[str, ...] = (
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.utils.condition.entities import Condition
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ class IfElseNodeData(BaseNodeData):
|
||||
If Else Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = BuiltinNodeTypes.IF_ELSE
|
||||
type: NodeType = NodeType.IF_ELSE
|
||||
|
||||
class Case(BaseModel):
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Literal
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.if_else.entities import IfElseNodeData
|
||||
@@ -13,7 +13,7 @@ from dify_graph.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class IfElseNode(Node[IfElseNodeData]):
|
||||
node_type = BuiltinNodeTypes.IF_ELSE
|
||||
node_type = NodeType.IF_ELSE
|
||||
execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
from pydantic import Field
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class IterationNodeData(BaseIterationNodeData):
|
||||
Iteration Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = BuiltinNodeTypes.ITERATION
|
||||
type: NodeType = NodeType.ITERATION
|
||||
parent_loop_id: str | None = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
@@ -34,7 +34,7 @@ class IterationStartNodeData(BaseNodeData):
|
||||
Iteration Start Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = BuiltinNodeTypes.ITERATION_START
|
||||
type: NodeType = NodeType.ITERATION_START
|
||||
|
||||
|
||||
class IterationState(BaseIterationState):
|
||||
|
||||
@@ -9,8 +9,8 @@ from typing_extensions import TypeIs
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.enums import (
|
||||
BuiltinNodeTypes,
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
@@ -62,7 +62,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
Iteration Node.
|
||||
"""
|
||||
|
||||
node_type = BuiltinNodeTypes.ITERATION
|
||||
node_type = NodeType.ITERATION
|
||||
execution_type = NodeExecutionType.CONTAINER
|
||||
|
||||
@classmethod
|
||||
@@ -486,16 +486,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
# TODO: Remove this workflow-layer import once node-class lookup no longer depends on node_factory.
|
||||
from core.workflow.node_factory import get_node_type_classes_mapping
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
|
||||
node_type = typed_sub_node_config["data"].type
|
||||
node_mapping = get_node_type_classes_mapping()
|
||||
if node_type not in node_mapping:
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
continue
|
||||
node_version = str(typed_sub_node_config["data"].version)
|
||||
node_cls = node_mapping[node_type][node_version]
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=typed_sub_node_config
|
||||
@@ -564,7 +562,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found")
|
||||
current_index = index_variable.value
|
||||
for event in rst:
|
||||
if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.ITERATION_START:
|
||||
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.ITERATION_START:
|
||||
continue
|
||||
|
||||
if isinstance(event, GraphNodeEventBase):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.iteration.entities import IterationStartNodeData
|
||||
@@ -9,7 +9,7 @@ class IterationStartNode(Node[IterationStartNodeData]):
|
||||
Iteration Start Node.
|
||||
"""
|
||||
|
||||
node_type = BuiltinNodeTypes.ITERATION_START
|
||||
node_type = NodeType.ITERATION_START
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
|
||||
3
api/dify_graph/nodes/knowledge_index/__init__.py
Normal file
3
api/dify_graph/nodes/knowledge_index/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .knowledge_index_node import KnowledgeIndexNode
|
||||
|
||||
__all__ = ["KnowledgeIndexNode"]
|
||||
@@ -3,7 +3,6 @@ from typing import Literal, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
@@ -157,7 +156,7 @@ class KnowledgeIndexNodeData(BaseNodeData):
|
||||
Knowledge index Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = KNOWLEDGE_INDEX_NODE_TYPE
|
||||
type: NodeType = NodeType.KNOWLEDGE_INDEX
|
||||
chunk_structure: str
|
||||
index_chunk_variable_selector: list[str]
|
||||
indexing_technique: str | None = None
|
||||
@@ -2,15 +2,14 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.rag.index_processor.index_processor import IndexProcessor
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, SystemVariableKey
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.template import Template
|
||||
from dify_graph.repositories.index_processor_protocol import IndexProcessorProtocol
|
||||
from dify_graph.repositories.summary_index_service_protocol import SummaryIndexServiceProtocol
|
||||
|
||||
from .entities import KnowledgeIndexNodeData
|
||||
from .exc import (
|
||||
@@ -26,7 +25,7 @@ _INVOKE_FROM_DEBUGGER = "debugger"
|
||||
|
||||
|
||||
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
node_type = KNOWLEDGE_INDEX_NODE_TYPE
|
||||
node_type = NodeType.KNOWLEDGE_INDEX
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
def __init__(
|
||||
@@ -35,10 +34,12 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
index_processor: IndexProcessorProtocol,
|
||||
summary_index_service: SummaryIndexServiceProtocol,
|
||||
) -> None:
|
||||
super().__init__(id, config, graph_init_params, graph_runtime_state)
|
||||
self.index_processor = IndexProcessor()
|
||||
self.summary_index_service = SummaryIndex()
|
||||
self.index_processor = index_processor
|
||||
self.summary_index_service = summary_index_service
|
||||
|
||||
def _run(self) -> NodeRunResult: # type: ignore
|
||||
node_data = self.node_data
|
||||
3
api/dify_graph/nodes/knowledge_retrieval/__init__.py
Normal file
3
api/dify_graph/nodes/knowledge_retrieval/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
|
||||
__all__ = ["KnowledgeRetrievalNode"]
|
||||
@@ -4,7 +4,7 @@ from typing import Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
Knowledge retrieval Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
|
||||
type: NodeType = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
query_variable_selector: list[str] | None | str = None
|
||||
query_attachment_selector: list[str] | None | str = None
|
||||
dataset_ids: list[str]
|
||||
@@ -1,19 +1,12 @@
|
||||
"""Knowledge retrieval workflow node implementation.
|
||||
|
||||
This node now lives under ``core.workflow.nodes`` and is discovered directly by
|
||||
the workflow node registry.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
BuiltinNodeTypes,
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
@@ -22,6 +15,7 @@ from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base import LLMUsageTrackingMixin
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
|
||||
from dify_graph.variables import (
|
||||
ArrayFileSegment,
|
||||
FileSegment,
|
||||
@@ -38,7 +32,6 @@ from .exc import (
|
||||
KnowledgeRetrievalNodeError,
|
||||
RateLimitExceededError,
|
||||
)
|
||||
from .retrieval import KnowledgeRetrievalRequest, Source
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.file.models import File
|
||||
@@ -48,7 +41,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
|
||||
node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
|
||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
# Instance attributes specific to LLMNode.
|
||||
# Output variable for file
|
||||
@@ -60,6 +53,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
rag_retrieval: RAGRetrievalProtocol,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -69,7 +63,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs = []
|
||||
self._rag_retrieval = DatasetRetrieval()
|
||||
self._rag_retrieval = rag_retrieval
|
||||
|
||||
@classmethod
|
||||
def version(cls):
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user