mirror of
https://github.com/langgenius/dify.git
synced 2026-03-14 03:37:02 +00:00
Compare commits
45 Commits
review-mys
...
refactor/w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
62b40b888c | ||
|
|
5dda0eff0e | ||
|
|
ff824917d5 | ||
|
|
5e07cb4b0f | ||
|
|
573b4e41cb | ||
|
|
194c205ed3 | ||
|
|
7e1dc3c122 | ||
|
|
4203647c32 | ||
|
|
20e91990bf | ||
|
|
f38e8cca52 | ||
|
|
00eda73ad1 | ||
|
|
8b40a89add | ||
|
|
97776eabff | ||
|
|
fe561ef3d0 | ||
|
|
1104d35bbb | ||
|
|
724eaee77e | ||
|
|
4717168fe2 | ||
|
|
7fd3bd81ab | ||
|
|
0dcfac5b84 | ||
|
|
b66097b5f3 | ||
|
|
ceaa399351 | ||
|
|
dc50e4c4f2 | ||
|
|
157208ab1e | ||
|
|
3dabdc8282 | ||
|
|
ed5511ce28 | ||
|
|
68982f910e | ||
|
|
c43307dae1 | ||
|
|
b44b37518a | ||
|
|
b170eabaf3 | ||
|
|
e99628b76f | ||
|
|
60fe5e7f00 | ||
|
|
245f6b824d | ||
|
|
7d2054d4f4 | ||
|
|
07e19c0748 | ||
|
|
135b3a15a6 | ||
|
|
0045e387f5 | ||
|
|
44713a5c0f | ||
|
|
d5724aebde | ||
|
|
c59685748c | ||
|
|
36c1f4d506 | ||
|
|
31eba65fe0 | ||
|
|
72496a5847 | ||
|
|
8b16030d6b | ||
|
|
989db0e584 | ||
|
|
a0f0c97133 |
34
.github/actions/setup-web/action.yml
vendored
34
.github/actions/setup-web/action.yml
vendored
@@ -1,33 +1,13 @@
|
||||
name: Setup Web Environment
|
||||
description: Setup pnpm, Node.js, and install web dependencies.
|
||||
|
||||
inputs:
|
||||
node-version:
|
||||
description: Node.js version to use
|
||||
required: false
|
||||
default: "22"
|
||||
install-dependencies:
|
||||
description: Whether to install web dependencies after setting up Node.js
|
||||
required: false
|
||||
default: "true"
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@41ff72655975bd51cab0327fa583b6e92b6d3061 # v4.2.0
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@b5d848f5a62488f3d3d920f8aa6ac318a60c5f07 # v1
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
|
||||
with:
|
||||
node-version: ${{ inputs.node-version }}
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
if: ${{ inputs.install-dependencies == 'true' }}
|
||||
shell: bash
|
||||
run: pnpm --dir web install --frozen-lockfile
|
||||
node-version-file: "./web/.nvmrc"
|
||||
cache: true
|
||||
run-install: |
|
||||
- cwd: ./web
|
||||
args: ['--frozen-lockfile']
|
||||
|
||||
4
.github/workflows/autofix.yml
vendored
4
.github/workflows/autofix.yml
vendored
@@ -102,13 +102,11 @@ jobs:
|
||||
- name: Setup web environment
|
||||
if: steps.web-changes.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
with:
|
||||
node-version: "24"
|
||||
|
||||
- name: ESLint autofix
|
||||
if: steps.web-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd web
|
||||
pnpm eslint --concurrency=2 --prune-suppressions --quiet || true
|
||||
vp exec eslint --concurrency=2 --prune-suppressions --quiet || true
|
||||
|
||||
- uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
|
||||
|
||||
3
.github/workflows/main-ci.yml
vendored
3
.github/workflows/main-ci.yml
vendored
@@ -62,6 +62,9 @@ jobs:
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.web-changed == 'true'
|
||||
uses: ./.github/workflows/web-tests.yml
|
||||
with:
|
||||
base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }}
|
||||
head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
|
||||
style-check:
|
||||
name: Style Check
|
||||
|
||||
8
.github/workflows/style.yml
vendored
8
.github/workflows/style.yml
vendored
@@ -88,7 +88,7 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: |
|
||||
pnpm run lint:ci
|
||||
vp run lint:ci
|
||||
# pnpm run lint:report
|
||||
# continue-on-error: true
|
||||
|
||||
@@ -102,17 +102,17 @@ jobs:
|
||||
- name: Web tsslint
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run lint:tss
|
||||
run: vp run lint:tss
|
||||
|
||||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run type-check
|
||||
run: vp run type-check
|
||||
|
||||
- name: Web dead code check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run knip
|
||||
run: vp run knip
|
||||
|
||||
superlinter:
|
||||
name: SuperLinter
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@@ -50,8 +50,6 @@ jobs:
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
with:
|
||||
install-dependencies: "false"
|
||||
|
||||
- name: Detect changed files and generate diff
|
||||
id: detect_changes
|
||||
|
||||
24
.github/workflows/web-tests.yml
vendored
24
.github/workflows/web-tests.yml
vendored
@@ -2,6 +2,13 @@ name: Web Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
base_sha:
|
||||
required: false
|
||||
type: string
|
||||
head_sha:
|
||||
required: false
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -14,6 +21,8 @@ jobs:
|
||||
test:
|
||||
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
VITEST_COVERAGE_SCOPE: app-components
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -34,7 +43,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
run: vp test run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
|
||||
- name: Upload blob report
|
||||
if: ${{ !cancelled() }}
|
||||
@@ -50,6 +59,8 @@ jobs:
|
||||
if: ${{ !cancelled() }}
|
||||
needs: [test]
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
VITEST_COVERAGE_SCOPE: app-components
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -59,6 +70,7 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
@@ -72,7 +84,13 @@ jobs:
|
||||
merge-multiple: true
|
||||
|
||||
- name: Merge reports
|
||||
run: pnpm vitest --merge-reports --coverage --silent=passed-only
|
||||
run: vp test --merge-reports --reporter=json --reporter=agent --coverage
|
||||
|
||||
- name: Check app/components diff coverage
|
||||
env:
|
||||
BASE_SHA: ${{ inputs.base_sha }}
|
||||
HEAD_SHA: ${{ inputs.head_sha }}
|
||||
run: node ./scripts/check-components-diff-coverage.mjs
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
@@ -429,4 +447,4 @@ jobs:
|
||||
- name: Web build check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run build
|
||||
run: vp run build
|
||||
|
||||
@@ -188,7 +188,6 @@ VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
# Weaviate configuration
|
||||
WEAVIATE_ENDPOINT=http://localhost:8080
|
||||
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
WEAVIATE_TOKENIZATION=word
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ forbidden_modules =
|
||||
extensions.ext_redis
|
||||
allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
|
||||
@@ -90,9 +89,6 @@ forbidden_modules =
|
||||
core.trigger
|
||||
core.variables
|
||||
ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> core.model_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.provider_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
dify_graph.nodes.llm.llm_utils -> core.model_manager
|
||||
dify_graph.nodes.llm.protocols -> core.model_manager
|
||||
dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
|
||||
@@ -100,8 +96,6 @@ ignore_imports =
|
||||
dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.agent.entities
|
||||
dify_graph.nodes.agent.agent_node -> core.agent.plugin_entities
|
||||
dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
|
||||
@@ -110,12 +104,10 @@ ignore_imports =
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.agent.agent_node -> models.model
|
||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
dify_graph.nodes.llm.node -> core.model_manager
|
||||
dify_graph.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util
|
||||
@@ -126,15 +118,11 @@ ignore_imports =
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
|
||||
dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
|
||||
dify_graph.nodes.llm.node -> models.dataset
|
||||
dify_graph.nodes.agent.agent_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.signature
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.errors
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.agent.agent_node -> models
|
||||
dify_graph.nodes.llm.node -> models.model
|
||||
dify_graph.nodes.agent.agent_node -> services
|
||||
dify_graph.nodes.tool.tool_node -> services
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> configs
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
|
||||
@@ -17,11 +17,6 @@ class WeaviateConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
WEAVIATE_GRPC_ENABLED: bool = Field(
|
||||
description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)",
|
||||
default=True,
|
||||
)
|
||||
|
||||
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
|
||||
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
|
||||
default=None,
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Literal
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@@ -169,6 +170,20 @@ register_enum_models(
|
||||
)
|
||||
|
||||
|
||||
def _read_upload_content(file: FileStorage, max_size: int) -> bytes:
|
||||
"""
|
||||
Read the uploaded file and validate its actual size before delegating to the plugin service.
|
||||
|
||||
FileStorage.content_length is not reliable for multipart test uploads and may be zero even when
|
||||
content exists, so the controllers validate against the loaded bytes instead.
|
||||
"""
|
||||
content = file.read()
|
||||
if len(content) > max_size:
|
||||
raise ValueError("File size exceeds the maximum allowed size")
|
||||
|
||||
return content
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/debugging-key")
|
||||
class PluginDebuggingKeyApi(Resource):
|
||||
@setup_required
|
||||
@@ -284,12 +299,7 @@ class PluginUploadFromPkgApi(Resource):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
file = request.files["pkg"]
|
||||
|
||||
# check file size
|
||||
if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
|
||||
raise ValueError("File size exceeds the maximum allowed size")
|
||||
|
||||
content = file.read()
|
||||
content = _read_upload_content(file, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
|
||||
try:
|
||||
response = PluginService.upload_pkg(tenant_id, content)
|
||||
except PluginDaemonClientSideError as e:
|
||||
@@ -328,12 +338,7 @@ class PluginUploadFromBundleApi(Resource):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
file = request.files["bundle"]
|
||||
|
||||
# check file size
|
||||
if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
|
||||
raise ValueError("File size exceeds the maximum allowed size")
|
||||
|
||||
content = file.read()
|
||||
content = _read_upload_content(file, dify_config.PLUGIN_MAX_BUNDLE_SIZE)
|
||||
try:
|
||||
response = PluginService.upload_bundle(tenant_id, content)
|
||||
except PluginDaemonClientSideError as e:
|
||||
|
||||
@@ -114,6 +114,7 @@ def get_user_tenant(view_func: Callable[P, R]):
|
||||
|
||||
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
try:
|
||||
data = request.get_json()
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.errors import AgentMaxIterationError
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
@@ -22,7 +23,6 @@ from dify_graph.model_runtime.entities.message_entities import (
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.nodes.agent.exc import AgentMaxIterationError
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
9
api/core/agent/errors.py
Normal file
9
api/core/agent/errors.py
Normal file
@@ -0,0 +1,9 @@
|
||||
class AgentMaxIterationError(Exception):
|
||||
"""Raised when an agent runner exceeds the configured max iteration count."""
|
||||
|
||||
def __init__(self, max_iteration: int):
|
||||
self.max_iteration = max_iteration
|
||||
super().__init__(
|
||||
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
|
||||
f"The agent was unable to complete the task within the allowed number of iterations."
|
||||
)
|
||||
@@ -5,6 +5,7 @@ from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.errors import AgentMaxIterationError
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
@@ -25,7 +26,6 @@ from dify_graph.model_runtime.entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from dify_graph.nodes.agent.exc import AgentMaxIterationError
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -114,7 +114,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
|
||||
@@ -113,7 +113,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
|
||||
@@ -113,7 +113,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
|
||||
@@ -3,7 +3,10 @@ import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
@@ -30,8 +33,10 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.node_resolution import resolve_workflow_node_class
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
@@ -62,8 +67,6 @@ from dify_graph.graph_events import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from dify_graph.graph_events.graph import GraphRunAbortedEvent
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
@@ -303,10 +306,12 @@ class WorkflowBasedAppRunner:
|
||||
if not target_node_config:
|
||||
raise ValueError(f"{node_type_label} node id not found in workflow graph")
|
||||
|
||||
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(target_node_config.get("data", {}).get("type"))
|
||||
node_version = target_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
node_type = target_node_config["data"].type
|
||||
node_version = str(target_node_config["data"].version)
|
||||
node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version)
|
||||
|
||||
# Use the variable pool from graph_runtime_state instead of creating a new one
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
@@ -334,6 +339,18 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
@staticmethod
|
||||
def _build_agent_strategy_info(event: NodeRunStartedEvent) -> AgentStrategyInfo | None:
|
||||
raw_agent_strategy = event.extras.get("agent_strategy")
|
||||
if raw_agent_strategy is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return AgentStrategyInfo.model_validate(raw_agent_strategy)
|
||||
except ValidationError:
|
||||
logger.warning("Invalid agent strategy payload for node %s", event.node_id, exc_info=True)
|
||||
return None
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle event
|
||||
@@ -419,7 +436,7 @@ class WorkflowBasedAppRunner:
|
||||
start_at=event.start_at,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
agent_strategy=self._build_agent_strategy_info(event),
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .agent_strategy import AgentStrategyInfo
|
||||
|
||||
__all__ = ["AgentStrategyInfo"]
|
||||
|
||||
8
api/core/app/entities/agent_strategy.py
Normal file
8
api/core/app/entities/agent_strategy.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class AgentStrategyInfo(BaseModel):
|
||||
name: str
|
||||
icon: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from dify_graph.entities import AgentNodeStrategyInit
|
||||
from dify_graph.entities.pause_reason import PauseReason
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
|
||||
@@ -314,7 +314,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
in_iteration_id: str | None = None
|
||||
in_loop_id: str | None = None
|
||||
start_at: datetime
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
agent_strategy: AgentStrategyInfo | None = None
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode, need to refactor
|
||||
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from dify_graph.entities import AgentNodeStrategyInit
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
@@ -349,7 +349,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
extras: dict[str, object] = Field(default_factory=dict)
|
||||
iteration_id: str | None = None
|
||||
loop_id: str | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
agent_strategy: AgentStrategyInfo | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_STARTED
|
||||
workflow_run_id: str
|
||||
|
||||
@@ -193,7 +193,8 @@ class LLMGenerator:
|
||||
error_step = "generate rule config"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
rule_config["error"] = str(e)
|
||||
error = str(e)
|
||||
error_step = "generate rule config"
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
@@ -279,7 +280,8 @@ class LLMGenerator:
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
rule_config["error"] = str(e)
|
||||
error = str(e)
|
||||
error_step = "handle unexpected exception"
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
|
||||
raise ValueError(f"The tool parameter value {repr(value)} is not in correct type of {as_normal_type(typ)}.")
|
||||
|
||||
|
||||
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
|
||||
|
||||
@@ -113,17 +113,26 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self.get_credentials_schema_by_type(CredentialType.API_KEY)
|
||||
|
||||
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
|
||||
def get_credentials_schema_by_type(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:param credential_type: the type of the credential
|
||||
:return: the credentials schema of the provider
|
||||
:param credential_type: the type of the credential, as CredentialType or str; str values
|
||||
are normalized via CredentialType.of and may raise ValueError for invalid values.
|
||||
:return: list[ProviderConfig] for CredentialType.OAUTH2 or CredentialType.API_KEY, an
|
||||
empty list for CredentialType.UNAUTHORIZED or missing schemas.
|
||||
|
||||
Reads from self.entity.oauth_schema and self.entity.credentials_schema.
|
||||
Raises ValueError for invalid credential types.
|
||||
"""
|
||||
if credential_type == CredentialType.OAUTH2.value:
|
||||
if isinstance(credential_type, str):
|
||||
credential_type = CredentialType.of(credential_type)
|
||||
if credential_type == CredentialType.OAUTH2:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
if credential_type == CredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
return []
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||
|
||||
@@ -137,6 +137,7 @@ class ToolFileManager:
|
||||
|
||||
session.add(tool_file)
|
||||
session.commit()
|
||||
session.refresh(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from core.trigger.debug.events import (
|
||||
build_plugin_pool_key,
|
||||
build_webhook_pool_key,
|
||||
)
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig
|
||||
@@ -41,10 +42,10 @@ class TriggerDebugEventPoller(ABC):
|
||||
app_id: str
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
node_config: Mapping[str, Any]
|
||||
node_config: NodeConfigDict
|
||||
node_id: str
|
||||
|
||||
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: Mapping[str, Any], node_id: str):
|
||||
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: NodeConfigDict, node_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self.app_id = app_id
|
||||
@@ -60,7 +61,7 @@ class PluginTriggerDebugEventPoller(TriggerDebugEventPoller):
|
||||
def poll(self) -> TriggerDebugEvent | None:
|
||||
from services.trigger.trigger_service import TriggerService
|
||||
|
||||
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config.get("data", {}))
|
||||
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config["data"], from_attributes=True)
|
||||
provider_id = TriggerProviderID(plugin_trigger_data.provider_id)
|
||||
pool_key: str = build_plugin_pool_key(
|
||||
name=plugin_trigger_data.event_name,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, cast, final
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -22,7 +22,15 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.node_resolution import resolve_workflow_node_class
|
||||
from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
|
||||
from core.workflow.nodes.agent.plugin_strategy_adapter import (
|
||||
PluginAgentStrategyPresentationProvider,
|
||||
PluginAgentStrategyResolver,
|
||||
)
|
||||
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import NodeType, SystemVariableKey
|
||||
from dify_graph.file.file_manager import file_manager
|
||||
@@ -31,26 +39,18 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.memory import PromptMessageMemory
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
|
||||
from dify_graph.nodes.code.code_node import WorkflowCodeExecutor
|
||||
from dify_graph.nodes.code.entities import CodeLanguage
|
||||
from dify_graph.nodes.code.limits import CodeNodeLimits
|
||||
from dify_graph.nodes.datasource import DatasourceNode
|
||||
from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||
from dify_graph.nodes.human_input.human_input_node import HumanInputNode
|
||||
from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
|
||||
from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from dify_graph.nodes.llm.entities import ModelConfig
|
||||
from dify_graph.nodes.document_extractor import UnstructuredApiConfig
|
||||
from dify_graph.nodes.http_request import build_http_request_config
|
||||
from dify_graph.nodes.llm.entities import LLMNodeData
|
||||
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from dify_graph.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
)
|
||||
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
@@ -60,6 +60,9 @@ if TYPE_CHECKING:
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
|
||||
|
||||
|
||||
def fetch_memory(
|
||||
*,
|
||||
conversation_id: str | None,
|
||||
@@ -100,10 +103,7 @@ class DefaultWorkflowCodeExecutor:
|
||||
@final
|
||||
class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
Default implementation of NodeFactory that uses the traditional node mapping.
|
||||
|
||||
This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
|
||||
and instantiating the appropriate node class.
|
||||
Default implementation of NodeFactory that resolves node classes from the live registry.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -146,6 +146,10 @@ class DifyNodeFactory(NodeFactory):
|
||||
)
|
||||
|
||||
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
|
||||
self._agent_strategy_resolver = PluginAgentStrategyResolver()
|
||||
self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider()
|
||||
self._agent_runtime_support = AgentRuntimeSupport()
|
||||
self._agent_message_transformer = AgentMessageTransformer()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext:
|
||||
@@ -157,178 +161,125 @@ class DifyNodeFactory(NodeFactory):
|
||||
return DifyRunContext.model_validate(raw_ctx)
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data using the traditional mapping.
|
||||
|
||||
:param node_config: node configuration dictionary containing type and other data
|
||||
:return: initialized Node instance
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
:raises ValueError: if node_config fails NodeConfigDict/BaseNodeData validation
|
||||
(including pydantic ValidationError, which subclasses ValueError),
|
||||
if node type is unknown, or if no implementation exists for the resolved version
|
||||
"""
|
||||
# Get node_id from config
|
||||
node_id = node_config["id"]
|
||||
|
||||
# Get node type from config
|
||||
node_data = node_config["data"]
|
||||
try:
|
||||
node_type = NodeType(node_data["type"])
|
||||
except ValueError:
|
||||
raise ValueError(f"Unknown node type: {node_data['type']}")
|
||||
|
||||
# Get node class
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
latest_node_class = node_mapping.get(LATEST_VERSION)
|
||||
node_version = str(node_data.get("version", "1"))
|
||||
matched_node_class = node_mapping.get(node_version)
|
||||
node_class = matched_node_class or latest_node_class
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
|
||||
# Create node instance
|
||||
if node_type == NodeType.CODE:
|
||||
return CodeNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
code_executor=self._code_executor,
|
||||
code_limits=self._code_limits,
|
||||
)
|
||||
|
||||
if node_type == NodeType.TEMPLATE_TRANSFORM:
|
||||
return TemplateTransformNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
template_renderer=self._template_renderer,
|
||||
max_output_length=self._template_transform_max_output_length,
|
||||
)
|
||||
|
||||
if node_type == NodeType.HTTP_REQUEST:
|
||||
return HttpRequestNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
http_request_config=self._http_request_config,
|
||||
http_client=self._http_request_http_client,
|
||||
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
|
||||
file_manager=self._http_request_file_manager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.HUMAN_INPUT:
|
||||
return HumanInputNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_INDEX:
|
||||
return KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
index_processor=IndexProcessor(),
|
||||
summary_index_service=SummaryIndex(),
|
||||
)
|
||||
|
||||
if node_type == NodeType.LLM:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return LLMNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DATASOURCE:
|
||||
return DatasourceNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
datasource_manager=DatasourceManager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
return KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
rag_retrieval=self._rag_retrieval,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DOCUMENT_EXTRACTOR:
|
||||
return DocumentExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
unstructured_api_config=self._document_extractor_unstructured_api_config,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.QUESTION_CLASSIFIER:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return QuestionClassifierNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return ParameterExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.TOOL:
|
||||
return ToolNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
tool_file_manager_factory=self._http_request_tool_file_manager_factory(),
|
||||
)
|
||||
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
|
||||
node_id = typed_node_config["id"]
|
||||
node_data = typed_node_config["data"]
|
||||
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
|
||||
node_type = node_data.type
|
||||
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
|
||||
NodeType.CODE: lambda: {
|
||||
"code_executor": self._code_executor,
|
||||
"code_limits": self._code_limits,
|
||||
},
|
||||
NodeType.TEMPLATE_TRANSFORM: lambda: {
|
||||
"template_renderer": self._template_renderer,
|
||||
"max_output_length": self._template_transform_max_output_length,
|
||||
},
|
||||
NodeType.HTTP_REQUEST: lambda: {
|
||||
"http_request_config": self._http_request_config,
|
||||
"http_client": self._http_request_http_client,
|
||||
"tool_file_manager_factory": self._http_request_tool_file_manager_factory,
|
||||
"file_manager": self._http_request_file_manager,
|
||||
},
|
||||
NodeType.HUMAN_INPUT: lambda: {
|
||||
"form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
|
||||
},
|
||||
NodeType.KNOWLEDGE_INDEX: lambda: {
|
||||
"index_processor": IndexProcessor(),
|
||||
"summary_index_service": SummaryIndex(),
|
||||
},
|
||||
NodeType.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
include_http_client=True,
|
||||
),
|
||||
NodeType.DATASOURCE: lambda: {
|
||||
"datasource_manager": DatasourceManager,
|
||||
},
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: lambda: {
|
||||
"rag_retrieval": self._rag_retrieval,
|
||||
},
|
||||
NodeType.DOCUMENT_EXTRACTOR: lambda: {
|
||||
"unstructured_api_config": self._document_extractor_unstructured_api_config,
|
||||
"http_client": self._http_request_http_client,
|
||||
},
|
||||
NodeType.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
include_http_client=True,
|
||||
),
|
||||
NodeType.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
include_http_client=False,
|
||||
),
|
||||
NodeType.TOOL: lambda: {
|
||||
"tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
|
||||
},
|
||||
NodeType.AGENT: lambda: {
|
||||
"strategy_resolver": self._agent_strategy_resolver,
|
||||
"presentation_provider": self._agent_strategy_presentation_provider,
|
||||
"runtime_support": self._agent_runtime_support,
|
||||
"message_transformer": self._agent_message_transformer,
|
||||
},
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
**node_init_kwargs,
|
||||
)
|
||||
|
||||
def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance:
|
||||
node_data_model = ModelConfig.model_validate(node_data["model"])
|
||||
@staticmethod
|
||||
def _validate_resolved_node_data(node_class: type[Node], node_data: BaseNodeData) -> BaseNodeData:
|
||||
"""
|
||||
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
|
||||
"""
|
||||
return node_class.validate_node_data(node_data)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
return resolve_workflow_node_class(node_type=node_type, node_version=node_version)
|
||||
|
||||
def _build_llm_compatible_node_init_kwargs(
|
||||
self,
|
||||
*,
|
||||
node_class: type[Node],
|
||||
node_data: BaseNodeData,
|
||||
include_http_client: bool,
|
||||
) -> dict[str, object]:
|
||||
validated_node_data = cast(
|
||||
LLMCompatibleNodeData,
|
||||
self._validate_resolved_node_data(node_class=node_class, node_data=node_data),
|
||||
)
|
||||
model_instance = self._build_model_instance_for_llm_node(validated_node_data)
|
||||
node_init_kwargs: dict[str, object] = {
|
||||
"credentials_provider": self._llm_credentials_provider,
|
||||
"model_factory": self._llm_model_factory,
|
||||
"model_instance": model_instance,
|
||||
"memory": self._build_memory_for_llm_node(
|
||||
node_data=validated_node_data,
|
||||
model_instance=model_instance,
|
||||
),
|
||||
}
|
||||
if include_http_client:
|
||||
node_init_kwargs["http_client"] = self._http_request_http_client
|
||||
return node_init_kwargs
|
||||
|
||||
def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance:
|
||||
node_data_model = node_data.model
|
||||
if not node_data_model.mode:
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
@@ -364,14 +315,12 @@ class DifyNodeFactory(NodeFactory):
|
||||
def _build_memory_for_llm_node(
|
||||
self,
|
||||
*,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: LLMCompatibleNodeData,
|
||||
model_instance: ModelInstance,
|
||||
) -> PromptMessageMemory | None:
|
||||
raw_memory_config = node_data.get("memory")
|
||||
if raw_memory_config is None:
|
||||
if node_data.memory is None:
|
||||
return None
|
||||
|
||||
node_memory = MemoryConfig.model_validate(raw_memory_config)
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
@@ -381,6 +330,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
return fetch_memory(
|
||||
conversation_id=conversation_id,
|
||||
app_id=self._dify_context.app_id,
|
||||
node_data_memory=node_memory,
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
42
api/core/workflow/node_resolution.py
Normal file
42
api/core/workflow/node_resolution.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from importlib import import_module
|
||||
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.node_mapping import LATEST_VERSION, get_node_type_classes_mapping
|
||||
|
||||
_WORKFLOW_NODE_MODULES = ("core.workflow.nodes.agent",)
|
||||
_workflow_nodes_registered = False
|
||||
|
||||
|
||||
def ensure_workflow_nodes_registered() -> None:
|
||||
"""Import workflow-local node modules so they can register with `Node.__init_subclass__`."""
|
||||
global _workflow_nodes_registered
|
||||
|
||||
if _workflow_nodes_registered:
|
||||
return
|
||||
|
||||
for module_name in _WORKFLOW_NODE_MODULES:
|
||||
import_module(module_name)
|
||||
|
||||
_workflow_nodes_registered = True
|
||||
|
||||
|
||||
def get_workflow_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]:
|
||||
ensure_workflow_nodes_registered()
|
||||
return get_node_type_classes_mapping()
|
||||
|
||||
|
||||
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
node_mapping = get_workflow_node_type_classes_mapping().get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
latest_node_class = node_mapping.get(LATEST_VERSION)
|
||||
matched_node_class = node_mapping.get(node_version)
|
||||
node_class = matched_node_class or latest_node_class
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
return node_class
|
||||
4
api/core/workflow/nodes/agent/__init__.py
Normal file
4
api/core/workflow/nodes/agent/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .agent_node import AgentNode
|
||||
from .entities import AgentNodeData
|
||||
|
||||
__all__ = ["AgentNode", "AgentNodeData"]
|
||||
188
api/core/workflow/nodes/agent/agent_node.py
Normal file
188
api/core/workflow/nodes/agent/agent_node.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, SystemVariableKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
from .entities import AgentNodeData
|
||||
from .exceptions import (
|
||||
AgentInvocationError,
|
||||
AgentMessageTransformError,
|
||||
)
|
||||
from .message_transformer import AgentMessageTransformer
|
||||
from .runtime_support import AgentRuntimeSupport
|
||||
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class AgentNode(Node[AgentNodeData]):
|
||||
node_type = NodeType.AGENT
|
||||
|
||||
_strategy_resolver: AgentStrategyResolver
|
||||
_presentation_provider: AgentStrategyPresentationProvider
|
||||
_runtime_support: AgentRuntimeSupport
|
||||
_message_transformer: AgentMessageTransformer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
strategy_resolver: AgentStrategyResolver,
|
||||
presentation_provider: AgentStrategyPresentationProvider,
|
||||
runtime_support: AgentRuntimeSupport,
|
||||
message_transformer: AgentMessageTransformer,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._strategy_resolver = strategy_resolver
|
||||
self._presentation_provider = presentation_provider
|
||||
self._runtime_support = runtime_support
|
||||
self._message_transformer = message_transformer
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def populate_start_event(self, event) -> None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
event.extras["agent_strategy"] = {
|
||||
"name": self.node_data.agent_strategy_name,
|
||||
"icon": self._presentation_provider.get_icon(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
),
|
||||
}
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
|
||||
try:
|
||||
strategy = self._strategy_resolver.resolve(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self.node_data.agent_strategy_name,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=f"Failed to get agent strategy: {str(e)}",
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
agent_parameters = strategy.get_parameters()
|
||||
|
||||
parameters = self._runtime_support.build_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
strategy=strategy,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
invoke_from=dify_ctx.invoke_from,
|
||||
)
|
||||
parameters_for_log = self._runtime_support.build_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
strategy=strategy,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
invoke_from=dify_ctx.invoke_from,
|
||||
for_log=True,
|
||||
)
|
||||
credentials = self._runtime_support.build_credentials(parameters=parameters)
|
||||
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = strategy.invoke(
|
||||
params=parameters,
|
||||
user_id=dify_ctx.user_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
except Exception as e:
|
||||
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(error),
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
yield from self._message_transformer.transform(
|
||||
messages=message_stream,
|
||||
tool_info={
|
||||
"icon": self._presentation_provider.get_icon(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
),
|
||||
"agent_strategy": self.node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
f"Failed to transform agent message: {str(e)}", original_error=e
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(transform_error),
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: AgentNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
result: dict[str, Any] = {}
|
||||
typed_node_data = node_data
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
match input.type:
|
||||
case "mixed" | "constant":
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
@@ -5,13 +5,15 @@ from pydantic import BaseModel
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class AgentNodeData(BaseNodeData):
|
||||
agent_strategy_provider_name: str # redundancy
|
||||
type: NodeType = NodeType.AGENT
|
||||
agent_strategy_provider_name: str
|
||||
agent_strategy_name: str
|
||||
agent_strategy_label: str # redundancy
|
||||
agent_strategy_label: str
|
||||
memory: MemoryConfig | None = None
|
||||
# The version of the tool parameter.
|
||||
# If this value is None, it indicates this is a previous version
|
||||
@@ -119,14 +119,3 @@ class AgentVariableTypeError(AgentNodeError):
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentMaxIterationError(AgentNodeError):
|
||||
"""Exception raised when the agent exceeds the maximum iteration limit."""
|
||||
|
||||
def __init__(self, max_iteration: int):
|
||||
self.max_iteration = max_iteration
|
||||
super().__init__(
|
||||
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
|
||||
f"The agent was unable to complete the task within the allowed number of iterations."
|
||||
)
|
||||
292
api/core/workflow/nodes/agent/message_transformer.py
Normal file
292
api/core/workflow/nodes/agent/message_transformer.py
Normal file
@@ -0,0 +1,292 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File, FileTransferMethod
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.node_events import (
|
||||
AgentLogEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from dify_graph.variables.segments import ArrayFileSegment
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError
|
||||
|
||||
|
||||
class AgentMessageTransformer:
|
||||
def transform(
|
||||
self,
|
||||
*,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json_list: list[dict | list] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
ToolInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == NodeType.AGENT:
|
||||
if isinstance(message.message.json_object, dict):
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
else:
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
agent_execution_metadata = {}
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise AgentVariableTypeError(
|
||||
"When 'stream' is True, 'variable_value' must be a string.",
|
||||
variable_name=variable_name,
|
||||
expected_type="str",
|
||||
actual_type=type(variable_value).__name__,
|
||||
)
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, dict)
|
||||
if "file" not in message.meta:
|
||||
raise AgentNodeError("File message is missing 'file' key in meta")
|
||||
|
||||
if not isinstance(message.meta["file"], File):
|
||||
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = tool_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
icon_dark = None
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
icon_dark = builtin_tool.icon_dark
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
message_id=message.message.id,
|
||||
node_execution_id=node_execution_id,
|
||||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
status=message.message.status.value,
|
||||
data=message.message.data,
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
for log in agent_logs:
|
||||
if log.message_id == agent_log.message_id:
|
||||
log.data = agent_log.data
|
||||
log.status = agent_log.status
|
||||
log.error = agent_log.error
|
||||
log.label = agent_log.label
|
||||
log.metadata = agent_log.metadata
|
||||
break
|
||||
else:
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
if agent_logs:
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.message_id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
)
|
||||
if json_list:
|
||||
json_output.extend(json_list)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"text": text,
|
||||
"usage": jsonable_encoder(llm_usage),
|
||||
"files": ArrayFileSegment(value=files),
|
||||
"json": json_output,
|
||||
**variables,
|
||||
},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
40
api/core/workflow/nodes/agent/plugin_strategy_adapter.py
Normal file
40
api/core/workflow/nodes/agent/plugin_strategy_adapter.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
|
||||
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver, ResolvedAgentStrategy
|
||||
|
||||
|
||||
class PluginAgentStrategyResolver(AgentStrategyResolver):
|
||||
def resolve(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
agent_strategy_provider_name: str,
|
||||
agent_strategy_name: str,
|
||||
) -> ResolvedAgentStrategy:
|
||||
return get_plugin_agent_strategy(
|
||||
tenant_id=tenant_id,
|
||||
agent_strategy_provider_name=agent_strategy_provider_name,
|
||||
agent_strategy_name=agent_strategy_name,
|
||||
)
|
||||
|
||||
|
||||
class PluginAgentStrategyPresentationProvider(AgentStrategyPresentationProvider):
|
||||
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None:
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
manager = PluginInstaller()
|
||||
try:
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == agent_strategy_provider_name
|
||||
)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
return current_plugin.declaration.icon
|
||||
276
api/core/workflow/nodes/agent/runtime_support.py
Normal file
276
api/core/workflow/nodes/agent/runtime_support.py
Normal file
@@ -0,0 +1,276 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from dify_graph.enums import SystemVariableKey
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
|
||||
from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
|
||||
from .strategy_protocols import ResolvedAgentStrategy
|
||||
|
||||
|
||||
class AgentRuntimeSupport:
|
||||
def build_parameters(
|
||||
self,
|
||||
*,
|
||||
agent_parameters: Sequence[AgentStrategyParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
invoke_from: Any,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
case _:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
if value_param and value_param.get("type", "") == "variable":
|
||||
variable_selector = value_param.get("value")
|
||||
if not variable_selector:
|
||||
raise ValueError("Variable selector is missing for a variable-type parameter.")
|
||||
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
|
||||
params[key] = variable.value
|
||||
else:
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_type=provider_type,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
runtime_variable_pool = variable_pool
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id,
|
||||
app_id,
|
||||
entity,
|
||||
invoke_from,
|
||||
runtime_variable_pool,
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
ToolParameter.ToolParameterForm.FORM
|
||||
if tool_runtime_params.name in manual_input_params
|
||||
else tool_runtime_params.form
|
||||
)
|
||||
manual_input_value = {}
|
||||
if tool_runtime.entity.parameters:
|
||||
manual_input_value = {
|
||||
key: value for key, value in parameters.items() if key in manual_input_params
|
||||
}
|
||||
runtime_parameters = {
|
||||
**tool_runtime.runtime.runtime_parameters,
|
||||
**manual_input_value,
|
||||
}
|
||||
tool_value.append(
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||
value = cast(dict[str, Any], value)
|
||||
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
memory = self.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=app_id,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
if memory:
|
||||
prompt_messages = memory.get_history_prompt_messages(
|
||||
message_limit=node_data.memory.window.size or None
|
||||
)
|
||||
history_prompt_messages = [
|
||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||
]
|
||||
value["history_prompt_messages"] = history_prompt_messages
|
||||
if model_schema:
|
||||
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
|
||||
value["entity"] = model_schema.model_dump(mode="json")
|
||||
else:
|
||||
value["entity"] = None
|
||||
result[parameter_name] = value
|
||||
|
||||
return result
|
||||
|
||||
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
|
||||
credentials = InvokeCredentials()
|
||||
credentials.tool_credentials = {}
|
||||
for tool in parameters.get("tools", []):
|
||||
if not tool.get("credential_id"):
|
||||
continue
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
except ValidationError:
|
||||
continue
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
return credentials
|
||||
|
||||
def fetch_memory(
|
||||
self,
|
||||
*,
|
||||
variable_pool: VariablePool,
|
||||
app_id: str,
|
||||
model_instance: ModelInstance,
|
||||
) -> TokenBufferMemory | None:
|
||||
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=tenant_id,
|
||||
provider=value.get("provider", ""),
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model=model_name,
|
||||
)
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
return model_instance, model_schema
|
||||
|
||||
@staticmethod
|
||||
def _remove_unsupported_model_features_for_old_version(model_schema: AIModelEntity) -> AIModelEntity:
|
||||
if model_schema.features:
|
||||
for feature in model_schema.features[:]:
|
||||
try:
|
||||
AgentOldVersionModelFeatures(feature.value)
|
||||
except ValueError:
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
@staticmethod
|
||||
def _filter_mcp_type_tool(
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tools: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
39
api/core/workflow/nodes/agent/strategy_protocols.py
Normal file
39
api/core/workflow/nodes/agent/strategy_protocols.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class ResolvedAgentStrategy(Protocol):
|
||||
meta_version: str | None
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]: ...
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
*,
|
||||
params: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
credentials: InvokeCredentials | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]: ...
|
||||
|
||||
|
||||
class AgentStrategyResolver(Protocol):
|
||||
def resolve(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
agent_strategy_provider_name: str,
|
||||
agent_strategy_name: str,
|
||||
) -> ResolvedAgentStrategy: ...
|
||||
|
||||
|
||||
class AgentStrategyPresentationProvider(Protocol):
|
||||
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None: ...
|
||||
@@ -9,9 +9,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.node_resolution import resolve_workflow_node_class
|
||||
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.errors import WorkflowNodeRunFailedError
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.graph import Graph
|
||||
@@ -23,7 +24,6 @@ from dify_graph.graph_engine.protocols.command_channel import CommandChannel
|
||||
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
@@ -212,7 +212,7 @@ class WorkflowEntry:
|
||||
node_config_data = node_config["data"]
|
||||
|
||||
# Get node type
|
||||
node_type = NodeType(node_config_data["type"])
|
||||
node_type = node_config_data.type
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -234,8 +234,7 @@ class WorkflowEntry:
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
typed_node_config = cast(dict[str, object], node_config)
|
||||
node = cast(Any, node_factory).create_node(typed_node_config)
|
||||
node = node_factory.create_node(node_config)
|
||||
node_cls = type(node)
|
||||
|
||||
try:
|
||||
@@ -344,7 +343,7 @@ class WorkflowEntry:
|
||||
if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
|
||||
raise ValueError(f"Node type {node_type} not supported")
|
||||
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"]
|
||||
node_cls = resolve_workflow_node_class(node_type=node_type, node_version="1")
|
||||
if not node_cls:
|
||||
raise ValueError(f"Node class not found for node type {node_type}")
|
||||
|
||||
@@ -371,10 +370,7 @@ class WorkflowEntry:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node_config: NodeConfigDict = {
|
||||
"id": node_id,
|
||||
"data": cast(NodeConfigData, node_data),
|
||||
}
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
from .workflow_start_reason import WorkflowStartReason
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentNodeStrategyInit(BaseModel):
|
||||
"""Agent node strategy initialization data."""
|
||||
|
||||
name: str
|
||||
icon: str | None = None
|
||||
176
api/dify_graph/entities/base_node_data.py
Normal file
176
api/dify_graph/entities/base_node_data.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from dify_graph.entities.exc import DefaultValueTypeError
|
||||
from dify_graph.enums import ErrorStrategy, NodeType
|
||||
|
||||
# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`.
|
||||
_NumberType = Union[int, float]
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
OBJECT = "object"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILES = "array[file]"
|
||||
|
||||
|
||||
class DefaultValue(BaseModel):
|
||||
value: Any = None
|
||||
type: DefaultValueType
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str):
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
|
||||
"""Unified array type validation"""
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
"""Unified number conversion handler"""
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators: dict[DefaultValueType, dict[str, Any]] = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
},
|
||||
DefaultValueType.NUMBER: {
|
||||
"type": _NumberType,
|
||||
"converter": self._convert_number,
|
||||
},
|
||||
DefaultValueType.OBJECT: {
|
||||
"type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_NUMBER: {
|
||||
"type": list,
|
||||
"element_type": _NumberType,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_STRING: {
|
||||
"type": list,
|
||||
"element_type": str,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_OBJECT: {
|
||||
"type": list,
|
||||
"element_type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
}
|
||||
|
||||
validator: dict[str, Any] = type_validators.get(self.type, {})
|
||||
if not validator:
|
||||
if self.type == DefaultValueType.ARRAY_FILES:
|
||||
# Handle files type
|
||||
return self
|
||||
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
|
||||
|
||||
# Handle string input cases
|
||||
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
|
||||
self.value = validator["converter"](self.value)
|
||||
|
||||
# Validate base type
|
||||
if not isinstance(self.value, validator["type"]):
|
||||
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
|
||||
|
||||
# Validate array element types
|
||||
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
|
||||
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
# Raw graph payloads are first validated through `NodeConfigDictAdapter`, where
|
||||
# `node["data"]` is typed as `BaseNodeData` before the concrete node class is known.
|
||||
# At that boundary, node-specific fields are still "extra" relative to this shared DTO,
|
||||
# and persisted templates/workflows also carry undeclared compatibility keys such as
|
||||
# `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive
|
||||
# here until graph parsing becomes discriminated by node type or those legacy payloads
|
||||
# are normalized.
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
type: NodeType
|
||||
title: str = ""
|
||||
desc: str | None = None
|
||||
version: str = "1"
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = Field(default_factory=RetryConfig)
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
if self.default_value:
|
||||
return {item.key: item.value for item in self.default_value}
|
||||
return {}
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""
|
||||
Dict-style access without calling model_dump() on every lookup.
|
||||
Prefer using model fields and Pydantic's extra storage.
|
||||
"""
|
||||
# First, check declared model fields
|
||||
if key in self.__class__.model_fields:
|
||||
return getattr(self, key)
|
||||
|
||||
# Then, check undeclared compatibility fields stored in Pydantic's extra dict.
|
||||
extras = getattr(self, "__pydantic_extra__", None)
|
||||
if extras is None:
|
||||
extras = getattr(self, "model_extra", None)
|
||||
if extras is not None and key in extras:
|
||||
return extras[key]
|
||||
|
||||
raise KeyError(key)
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Dict-style .get() without calling model_dump() on every lookup.
|
||||
"""
|
||||
if key in self.__class__.model_fields:
|
||||
return getattr(self, key)
|
||||
|
||||
extras = getattr(self, "__pydantic_extra__", None)
|
||||
if extras is None:
|
||||
extras = getattr(self, "model_extra", None)
|
||||
if extras is not None and key in extras:
|
||||
return extras.get(key, default)
|
||||
|
||||
return default
|
||||
@@ -4,21 +4,20 @@ import sys
|
||||
|
||||
from pydantic import TypeAdapter, with_config
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigData(TypedDict):
|
||||
type: str
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigDict(TypedDict):
|
||||
id: str
|
||||
data: NodeConfigData
|
||||
# This is the permissive raw graph boundary. Node factories re-validate `data`
|
||||
# with the concrete `NodeData` subtype after resolving the node implementation.
|
||||
data: BaseNodeData
|
||||
|
||||
|
||||
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Protocol, cast, final
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
|
||||
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from libs.typing import is_str
|
||||
|
||||
@@ -34,7 +34,8 @@ class NodeFactory(Protocol):
|
||||
|
||||
:param node_config: node configuration dictionary containing type and other data
|
||||
:return: initialized Node instance
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
:raises ValueError: if node type is unknown or no implementation exists for the resolved version
|
||||
:raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -115,10 +116,7 @@ class Graph:
|
||||
start_node_id = None
|
||||
for nid in root_candidates:
|
||||
node_data = node_configs_map[nid]["data"]
|
||||
node_type = node_data["type"]
|
||||
if not isinstance(node_type, str):
|
||||
continue
|
||||
if NodeType(node_type).is_start_node:
|
||||
if node_data.type.is_start_node:
|
||||
start_node_id = nid
|
||||
break
|
||||
|
||||
@@ -203,6 +201,23 @@ class Graph:
|
||||
|
||||
return GraphBuilder(graph_cls=cls)
|
||||
|
||||
@staticmethod
|
||||
def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]:
|
||||
"""
|
||||
Remove editor-only nodes before `NodeConfigDict` validation.
|
||||
|
||||
Persisted note widgets use a top-level `type == "custom-note"` but leave
|
||||
`data.type` empty because they are never executable graph nodes. Filter
|
||||
them while configs are still raw dicts so Pydantic does not validate
|
||||
their placeholder payloads against `BaseNodeData.type: NodeType`.
|
||||
"""
|
||||
filtered_node_configs: list[dict[str, object]] = []
|
||||
for node_config in node_configs:
|
||||
if node_config.get("type", "") == "custom-note":
|
||||
continue
|
||||
filtered_node_configs.append(dict(node_config))
|
||||
return filtered_node_configs
|
||||
|
||||
@classmethod
|
||||
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
|
||||
"""
|
||||
@@ -302,13 +317,13 @@ class Graph:
|
||||
node_configs = graph_config.get("nodes", [])
|
||||
|
||||
edge_configs = cast(list[dict[str, object]], edge_configs)
|
||||
node_configs = cast(list[dict[str, object]], node_configs)
|
||||
node_configs = cls._filter_canvas_only_nodes(node_configs)
|
||||
node_configs = _ListNodeConfigDict.validate_python(node_configs)
|
||||
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
|
||||
|
||||
# Parse node configurations
|
||||
node_configs_map = cls._parse_node_configs(node_configs)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from datetime import datetime
|
||||
from pydantic import Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from dify_graph.entities import AgentNodeStrategyInit
|
||||
from dify_graph.entities.pause_reason import PauseReason
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
@@ -13,8 +12,8 @@ from .base import GraphNodeEventBase
|
||||
class NodeRunStartedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
predecessor_node_id: str | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
extras: dict[str, object] = Field(default_factory=dict)
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode
|
||||
provider_type: str = ""
|
||||
|
||||
@@ -276,7 +276,4 @@ class ToolPromptMessage(PromptMessage):
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
if not super().is_empty() and not self.tool_call_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
return super().is_empty() and not self.tool_call_id
|
||||
|
||||
@@ -4,7 +4,8 @@ class InvokeError(ValueError):
|
||||
description: str | None = None
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
self.description = description
|
||||
if description is not None:
|
||||
self.description = description
|
||||
|
||||
def __str__(self):
|
||||
return self.description or self.__class__.__name__
|
||||
|
||||
@@ -282,7 +282,8 @@ class ModelProviderFactory:
|
||||
all_model_type_models.append(model_schema)
|
||||
|
||||
simple_provider_schema = provider_schema.to_simple_provider()
|
||||
simple_provider_schema.models.extend(all_model_type_models)
|
||||
if model_type:
|
||||
simple_provider_schema.models = all_model_type_models
|
||||
|
||||
providers.append(simple_provider_schema)
|
||||
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .agent_node import AgentNode
|
||||
|
||||
__all__ = ["AgentNode"]
|
||||
@@ -1,762 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from dify_graph.file import File, FileTransferMethod
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.node_events import (
|
||||
AgentLogEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from dify_graph.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables.segments import ArrayFileSegment, StringSegment
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
from models import ToolFile
|
||||
from models.model import Conversation
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .exc import (
|
||||
AgentInputTypeError,
|
||||
AgentInvocationError,
|
||||
AgentMessageTransformError,
|
||||
AgentNodeError,
|
||||
AgentVariableNotFoundError,
|
||||
AgentVariableTypeError,
|
||||
ToolFileNotFoundError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
|
||||
class AgentNode(Node[AgentNodeData]):
|
||||
"""
|
||||
Agent Node
|
||||
"""
|
||||
|
||||
node_type = NodeType.AGENT
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
|
||||
try:
|
||||
strategy = get_plugin_agent_strategy(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self.node_data.agent_strategy_name,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=f"Failed to get agent strategy: {str(e)}",
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
agent_parameters = strategy.get_parameters()
|
||||
|
||||
# get parameters
|
||||
parameters = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
strategy=strategy,
|
||||
)
|
||||
parameters_for_log = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
strategy=strategy,
|
||||
)
|
||||
credentials = self._generate_credentials(parameters=parameters)
|
||||
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = strategy.invoke(
|
||||
params=parameters,
|
||||
user_id=dify_ctx.user_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
except Exception as e:
|
||||
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(error),
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
yield from self._transform_message(
|
||||
messages=message_stream,
|
||||
tool_info={
|
||||
"icon": self.agent_strategy_icon,
|
||||
"agent_strategy": self.node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
f"Failed to transform agent message: {str(e)}", original_error=e
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(transform_error),
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_agent_parameters(
|
||||
self,
|
||||
*,
|
||||
agent_parameters: Sequence[AgentStrategyParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
for_log: bool = False,
|
||||
strategy: PluginAgentStrategy,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
agent_parameters (Sequence[AgentParameter]): The list of agent parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (AgentNodeData): The data associated with the agent node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
case _:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
if value_param and value_param.get("type", "") == "variable":
|
||||
variable_selector = value_param.get("value")
|
||||
if not variable_selector:
|
||||
raise ValueError("Variable selector is missing for a variable-type parameter.")
|
||||
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
|
||||
params[key] = variable.value
|
||||
else:
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_type=provider_type,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use the node_data.version field for judgment
|
||||
# But for backward compatibility with historical data
|
||||
# this version field judgment is still preserved here.
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
runtime_variable_pool = variable_pool
|
||||
dify_ctx = self.require_dify_context()
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
dify_ctx.tenant_id,
|
||||
dify_ctx.app_id,
|
||||
entity,
|
||||
dify_ctx.invoke_from,
|
||||
runtime_variable_pool,
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
ToolParameter.ToolParameterForm.FORM
|
||||
if tool_runtime_params.name in manual_input_params
|
||||
else tool_runtime_params.form
|
||||
)
|
||||
manual_input_value = {}
|
||||
if tool_runtime.entity.parameters:
|
||||
manual_input_value = {
|
||||
key: value for key, value in parameters.items() if key in manual_input_params
|
||||
}
|
||||
runtime_parameters = {
|
||||
**tool_runtime.runtime.runtime_parameters,
|
||||
**manual_input_value,
|
||||
}
|
||||
tool_value.append(
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||
value = cast(dict[str, Any], value)
|
||||
model_instance, model_schema = self._fetch_model(value)
|
||||
# memory config
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
memory = self._fetch_memory(model_instance)
|
||||
if memory:
|
||||
prompt_messages = memory.get_history_prompt_messages(
|
||||
message_limit=node_data.memory.window.size or None
|
||||
)
|
||||
history_prompt_messages = [
|
||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||
]
|
||||
value["history_prompt_messages"] = history_prompt_messages
|
||||
if model_schema:
|
||||
# remove structured output feature to support old version agent plugin
|
||||
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
|
||||
value["entity"] = model_schema.model_dump(mode="json")
|
||||
else:
|
||||
value["entity"] = None
|
||||
result[parameter_name] = value
|
||||
|
||||
return result
|
||||
|
||||
def _generate_credentials(
|
||||
self,
|
||||
parameters: dict[str, Any],
|
||||
) -> InvokeCredentials:
|
||||
"""
|
||||
Generate credentials based on the given agent parameters.
|
||||
"""
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
credentials = InvokeCredentials()
|
||||
|
||||
# generate credentials for tools selector
|
||||
credentials.tool_credentials = {}
|
||||
for tool in parameters.get("tools", []):
|
||||
if tool.get("credential_id"):
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
except ValidationError:
|
||||
continue
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = AgentNodeData.model_validate(node_data)
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
match input.type:
|
||||
case "mixed" | "constant":
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def agent_strategy_icon(self) -> str | None:
|
||||
"""
|
||||
Get agent strategy icon
|
||||
:return:
|
||||
"""
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
manager = PluginInstaller()
|
||||
dify_ctx = self.require_dify_context()
|
||||
plugins = manager.list_plugins(dify_ctx.tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
icon = None
|
||||
return icon
|
||||
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
|
||||
# get conversation id
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
|
||||
)
|
||||
conversation = session.scalar(stmt)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
return memory
|
||||
|
||||
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
dify_ctx = self.require_dify_context()
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM, model=model_name
|
||||
)
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
return model_instance, model_schema
|
||||
|
||||
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
|
||||
if model_schema.features:
|
||||
for feature in model_schema.features[:]: # Create a copy to safely modify during iteration
|
||||
try:
|
||||
AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value
|
||||
except ValueError:
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter MCP type tool
|
||||
:param strategy: plugin agent strategy
|
||||
:param tool: tool
|
||||
:return: filtered tool dict
|
||||
"""
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
else:
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json_list: list[dict | list] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
ToolInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == NodeType.AGENT:
|
||||
if isinstance(message.message.json_object, dict):
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
else:
|
||||
msg_metadata = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
agent_execution_metadata = {}
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise AgentVariableTypeError(
|
||||
"When 'stream' is True, 'variable_value' must be a string.",
|
||||
variable_name=variable_name,
|
||||
expected_type="str",
|
||||
actual_type=type(variable_value).__name__,
|
||||
)
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, dict)
|
||||
# Validate that meta contains a 'file' key
|
||||
if "file" not in message.meta:
|
||||
raise AgentNodeError("File message is missing 'file' key in meta")
|
||||
|
||||
# Validate that the file is an instance of File
|
||||
if not isinstance(message.meta["file"], File):
|
||||
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = tool_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
icon_dark = None
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
icon_dark = builtin_tool.icon_dark
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
message_id=message.message.id,
|
||||
node_execution_id=node_execution_id,
|
||||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
status=message.message.status.value,
|
||||
data=message.message.data,
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
# check if the agent log is already in the list
|
||||
for log in agent_logs:
|
||||
if log.message_id == agent_log.message_id:
|
||||
# update the log
|
||||
log.data = agent_log.data
|
||||
log.status = agent_log.status
|
||||
log.error = agent_log.error
|
||||
log.label = agent_log.label
|
||||
log.metadata = agent_log.metadata
|
||||
break
|
||||
else:
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
|
||||
# Step 1: append each agent log as its own dict.
|
||||
if agent_logs:
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.message_id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
)
|
||||
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
|
||||
if json_list:
|
||||
json_output.extend(json_list)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
# Send final chunk events for all streamed outputs
|
||||
# Final chunk for text stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Final chunks for any streamed variables
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"text": text,
|
||||
"usage": jsonable_encoder(llm_usage),
|
||||
"files": ArrayFileSegment(value=files),
|
||||
"json": json_output,
|
||||
**variables,
|
||||
},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
@@ -48,12 +48,10 @@ class AnswerNode(Node[AnswerNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: AnswerNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = AnswerNodeData.model_validate(node_data)
|
||||
|
||||
variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
variable_mapping = {}
|
||||
|
||||
@@ -3,7 +3,8 @@ from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class AnswerNodeData(BaseNodeData):
|
||||
@@ -11,6 +12,7 @@ class AnswerNodeData(BaseNodeData):
|
||||
Answer Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.ANSWER
|
||||
answer: str = Field(..., description="answer template string")
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState
|
||||
from .usage_tracking_mixin import LLMUsageTrackingMixin
|
||||
|
||||
__all__ = [
|
||||
@@ -6,6 +6,5 @@ __all__ = [
|
||||
"BaseIterationState",
|
||||
"BaseLoopNodeData",
|
||||
"BaseLoopState",
|
||||
"BaseNodeData",
|
||||
"LLMUsageTrackingMixin",
|
||||
]
|
||||
|
||||
@@ -1,31 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from dify_graph.enums import ErrorStrategy
|
||||
|
||||
from .exc import DefaultValueTypeError
|
||||
|
||||
_NumberType = Union[int, float]
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
|
||||
|
||||
class VariableSelector(BaseModel):
|
||||
@@ -76,112 +57,6 @@ class OutputVariableEntity(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
OBJECT = "object"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILES = "array[file]"
|
||||
|
||||
|
||||
class DefaultValue(BaseModel):
|
||||
value: Any = None
|
||||
type: DefaultValueType
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str):
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
|
||||
"""Unified array type validation"""
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
"""Unified number conversion handler"""
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators: dict[DefaultValueType, dict[str, Any]] = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
},
|
||||
DefaultValueType.NUMBER: {
|
||||
"type": _NumberType,
|
||||
"converter": self._convert_number,
|
||||
},
|
||||
DefaultValueType.OBJECT: {
|
||||
"type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_NUMBER: {
|
||||
"type": list,
|
||||
"element_type": _NumberType,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_STRING: {
|
||||
"type": list,
|
||||
"element_type": str,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_OBJECT: {
|
||||
"type": list,
|
||||
"element_type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
}
|
||||
|
||||
validator: dict[str, Any] = type_validators.get(self.type, {})
|
||||
if not validator:
|
||||
if self.type == DefaultValueType.ARRAY_FILES:
|
||||
# Handle files type
|
||||
return self
|
||||
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
|
||||
|
||||
# Handle string input cases
|
||||
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
|
||||
self.value = validator["converter"](self.value)
|
||||
|
||||
# Validate base type
|
||||
if not isinstance(self.value, validator["type"]):
|
||||
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
|
||||
|
||||
# Validate array element types
|
||||
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
|
||||
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: str | None = None
|
||||
version: str = "1"
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
if self.default_value:
|
||||
return {item.key: item.value for item in self.default_value}
|
||||
return {}
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
@@ -11,7 +11,9 @@ from types import MappingProxyType
|
||||
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
|
||||
from uuid import uuid4
|
||||
|
||||
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import (
|
||||
ErrorStrategy,
|
||||
@@ -62,8 +64,6 @@ from dify_graph.node_events import (
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .entities import BaseNodeData, RetryConfig
|
||||
|
||||
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
|
||||
_MISSING_RUN_CONTEXT_VALUE = object()
|
||||
|
||||
@@ -153,11 +153,11 @@ class Node(Generic[NodeDataT]):
|
||||
Later, in __init__:
|
||||
::
|
||||
|
||||
config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
|
||||
│
|
||||
▼
|
||||
CodeNodeData instance
|
||||
(stored in self._node_data)
|
||||
config["data"] ──► _node_data_type.model_validate(..., from_attributes=True)
|
||||
│
|
||||
▼
|
||||
CodeNodeData instance
|
||||
(stored in self._node_data)
|
||||
|
||||
Example:
|
||||
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
|
||||
@@ -241,7 +241,7 @@ class Node(Generic[NodeDataT]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> None:
|
||||
@@ -254,22 +254,21 @@ class Node(Generic[NodeDataT]):
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.state: NodeState = NodeState.UNKNOWN # node execution state
|
||||
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
node_id = config["id"]
|
||||
|
||||
self._node_id = node_id
|
||||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
raw_node_data = config.get("data") or {}
|
||||
if not isinstance(raw_node_data, Mapping):
|
||||
raise ValueError("Node config data must be a mapping.")
|
||||
|
||||
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
|
||||
self._node_data = self.validate_node_data(config["data"])
|
||||
|
||||
self.post_init()
|
||||
|
||||
@classmethod
|
||||
def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT:
|
||||
"""Validate shared graph node payloads against the subclass-declared NodeData model."""
|
||||
return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True))
|
||||
|
||||
def post_init(self) -> None:
|
||||
"""Optional hook for subclasses requiring extra initialization."""
|
||||
return
|
||||
@@ -342,9 +341,6 @@ class Node(Generic[NodeDataT]):
|
||||
return None
|
||||
return str(execution_id)
|
||||
|
||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
@@ -353,6 +349,10 @@ class Node(Generic[NodeDataT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def populate_start_event(self, event: NodeRunStartedEvent) -> None:
|
||||
"""Allow subclasses to enrich the started event without cross-node imports in the base class."""
|
||||
_ = event
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
@@ -366,41 +366,10 @@ class Node(Generic[NodeDataT]):
|
||||
in_iteration_id=None,
|
||||
start_at=self._start_at,
|
||||
)
|
||||
|
||||
# === FIXME(-LAN-): Needs to refactor.
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
|
||||
if isinstance(self, ToolNode):
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from dify_graph.nodes.datasource.datasource_node import DatasourceNode
|
||||
|
||||
if isinstance(self, DatasourceNode):
|
||||
plugin_id = getattr(self.node_data, "plugin_id", "")
|
||||
provider_name = getattr(self.node_data, "provider_name", "")
|
||||
|
||||
start_event.provider_id = f"{plugin_id}/{provider_name}"
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from dify_graph.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
|
||||
|
||||
if isinstance(self, TriggerEventNode):
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from typing import cast
|
||||
|
||||
from dify_graph.nodes.agent.agent_node import AgentNode
|
||||
from dify_graph.nodes.agent.entities import AgentNodeData
|
||||
|
||||
if isinstance(self, AgentNode):
|
||||
start_event.agent_strategy = AgentNodeStrategyInit(
|
||||
name=cast(AgentNodeData, self.node_data).agent_strategy_name,
|
||||
icon=self.agent_strategy_icon,
|
||||
)
|
||||
|
||||
# ===
|
||||
try:
|
||||
self.populate_start_event(start_event)
|
||||
except Exception:
|
||||
logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True)
|
||||
yield start_event
|
||||
|
||||
try:
|
||||
@@ -442,7 +411,7 @@ class Node(Generic[NodeDataT]):
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""Extracts references variable selectors from node configuration.
|
||||
|
||||
@@ -480,13 +449,12 @@ class Node(Generic[NodeDataT]):
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
|
||||
# Pass raw dict data instead of creating NodeData instance
|
||||
node_id = config["id"]
|
||||
node_data = cls.validate_node_data(config["data"])
|
||||
data = cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
|
||||
graph_config=graph_config,
|
||||
node_id=node_id,
|
||||
node_data=node_data,
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -496,7 +464,7 @@ class Node(Generic[NodeDataT]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: NodeDataT,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
return {}
|
||||
|
||||
@@ -520,10 +488,8 @@ class Node(Generic[NodeDataT]):
|
||||
@abstractmethod
|
||||
def version(cls) -> str:
|
||||
"""`node_version` returns the version of current node type."""
|
||||
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
|
||||
#
|
||||
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
|
||||
# in `api/dify_graph/nodes/__init__.py`.
|
||||
# NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so
|
||||
# `Node.get_node_type_classes_mapping()` can resolve numeric versions and `latest`.
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@classmethod
|
||||
@@ -531,7 +497,9 @@ class Node(Generic[NodeDataT]):
|
||||
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
|
||||
|
||||
Import all modules under dify_graph.nodes so subclasses register themselves on import.
|
||||
Then we return a readonly view of the registry to avoid accidental mutation.
|
||||
Callers that rely on workflow-local nodes defined outside `dify_graph.nodes` must import
|
||||
those modules before invoking this method so they can register through `__init_subclass__`.
|
||||
We then return a readonly view of the registry to avoid accidental mutation.
|
||||
"""
|
||||
# Import all node modules to ensure they are loaded (thus registered)
|
||||
import dify_graph.nodes as _nodes_pkg
|
||||
|
||||
@@ -3,6 +3,7 @@ from decimal import Decimal
|
||||
from textwrap import dedent
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
@@ -77,7 +78,7 @@ class CodeNode(Node[CodeNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -466,15 +467,12 @@ class CodeNode(Node[CodeNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: CodeNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = CodeNodeData.model_validate(node_data)
|
||||
|
||||
return {
|
||||
node_id + "." + variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in typed_node_data.variables
|
||||
for variable_selector in node_data.variables
|
||||
}
|
||||
|
||||
@property
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Annotated, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
@@ -39,6 +40,8 @@ class CodeNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.CODE
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: dict[str, "CodeNodeData.Output"] | None = None
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from dify_graph.node_events import NodeRunResult, StreamCompletedEvent
|
||||
@@ -34,7 +35,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
datasource_manager: DatasourceManagerProtocol,
|
||||
@@ -47,6 +48,10 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
)
|
||||
self.datasource_manager = datasource_manager
|
||||
|
||||
def populate_start_event(self, event) -> None:
|
||||
event.provider_id = f"{self.node_data.plugin_id}/{self.node_data.provider_name}"
|
||||
event.provider_type = self.node_data.provider_type
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the datasource node
|
||||
@@ -181,7 +186,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: DatasourceNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -190,11 +195,10 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
typed_node_data = DatasourceNodeData.model_validate(node_data)
|
||||
result = {}
|
||||
if typed_node_data.datasource_parameters:
|
||||
for parameter_name in typed_node_data.datasource_parameters:
|
||||
input = typed_node_data.datasource_parameters[parameter_name]
|
||||
if node_data.datasource_parameters:
|
||||
for parameter_name in node_data.datasource_parameters:
|
||||
input = node_data.datasource_parameters[parameter_name]
|
||||
match input.type:
|
||||
case "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Any, Literal, Union
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
@@ -16,6 +17,8 @@ class DatasourceEntity(BaseModel):
|
||||
|
||||
|
||||
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
type: NodeType = NodeType.DATASOURCE
|
||||
|
||||
class DatasourceInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class DocumentExtractorNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.DOCUMENT_EXTRACTOR
|
||||
variable_selector: Sequence[str]
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from docx.oxml.text.paragraph import CT_P
|
||||
from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File, FileTransferMethod, file_manager
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
@@ -54,7 +55,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -136,12 +137,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: DocumentExtractorNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
|
||||
|
||||
return {node_id + ".files": typed_node_data.variable_selector}
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
return {node_id + ".files": node_data.variable_selector}
|
||||
|
||||
|
||||
def _extract_text_by_mime_type(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import OutputVariableEntity
|
||||
|
||||
|
||||
class EndNodeData(BaseNodeData):
|
||||
@@ -8,6 +10,7 @@ class EndNodeData(BaseNodeData):
|
||||
END Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.END
|
||||
outputs: list[OutputVariableEntity]
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,8 @@ import charset_normalizer
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config"
|
||||
|
||||
@@ -89,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.HTTP_REQUEST
|
||||
method: Literal[
|
||||
"get",
|
||||
"post",
|
||||
|
||||
@@ -3,6 +3,7 @@ import mimetypes
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File, FileTransferMethod
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
@@ -37,7 +38,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -163,18 +164,15 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: HttpRequestNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = HttpRequestNodeData.model_validate(node_data)
|
||||
|
||||
selectors: list[VariableSelector] = []
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
|
||||
if typed_node_data.body:
|
||||
body_type = typed_node_data.body.type
|
||||
data = typed_node_data.body.data
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
|
||||
if node_data.body:
|
||||
body_type = node_data.body.type
|
||||
data = node_data.body.data
|
||||
match body_type:
|
||||
case "none":
|
||||
pass
|
||||
|
||||
@@ -10,7 +10,8 @@ from typing import Annotated, Any, ClassVar, Literal, Self
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables.consts import SELECTORS_LENGTH
|
||||
@@ -71,8 +72,8 @@ class EmailDeliveryConfig(BaseModel):
|
||||
body: str
|
||||
debug_mode: bool = False
|
||||
|
||||
def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig":
|
||||
if not user_id:
|
||||
def with_debug_recipient(self, user_id: str | None) -> "EmailDeliveryConfig":
|
||||
if user_id is None:
|
||||
debug_recipients = EmailRecipients(whole_workspace=False, items=[])
|
||||
return self.model_copy(update={"recipients": debug_recipients})
|
||||
debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)])
|
||||
@@ -140,7 +141,7 @@ def apply_debug_email_recipient(
|
||||
method: DeliveryChannelConfig,
|
||||
*,
|
||||
enabled: bool,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
) -> DeliveryChannelConfig:
|
||||
if not enabled:
|
||||
return method
|
||||
@@ -148,7 +149,7 @@ def apply_debug_email_recipient(
|
||||
return method
|
||||
if not method.config.debug_mode:
|
||||
return method
|
||||
debug_config = method.config.with_debug_recipient(user_id or "")
|
||||
debug_config = method.config.with_debug_recipient(user_id)
|
||||
return method.model_copy(update={"config": debug_config})
|
||||
|
||||
|
||||
@@ -214,6 +215,7 @@ class UserAction(BaseModel):
|
||||
class HumanInputNodeData(BaseNodeData):
|
||||
"""Human Input node data."""
|
||||
|
||||
type: NodeType = NodeType.HUMAN_INPUT
|
||||
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
|
||||
form_content: str = ""
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import (
|
||||
@@ -63,7 +64,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
form_repository: HumanInputFormRepository,
|
||||
@@ -348,7 +349,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: HumanInputNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selectors referenced in form content and input default values.
|
||||
@@ -357,5 +358,4 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
1. Variables referenced in form_content ({{#node_name.var_name#}})
|
||||
2. Variables referenced in input default values
|
||||
"""
|
||||
validated_node_data = HumanInputNodeData.model_validate(node_data)
|
||||
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
|
||||
return node_data.extract_variable_selector_to_variable_mapping(node_id)
|
||||
|
||||
@@ -2,7 +2,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.utils.condition.entities import Condition
|
||||
|
||||
|
||||
@@ -11,6 +12,8 @@ class IfElseNodeData(BaseNodeData):
|
||||
If Else Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.IF_ELSE
|
||||
|
||||
class Case(BaseModel):
|
||||
"""
|
||||
Case entity representing a single logical condition group
|
||||
|
||||
@@ -97,13 +97,11 @@ class IfElseNode(Node[IfElseNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: IfElseNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = IfElseNodeData.model_validate(node_data)
|
||||
|
||||
var_mapping: dict[str, list[str]] = {}
|
||||
for case in typed_node_data.cases or []:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
for case in node_data.cases or []:
|
||||
for condition in case.conditions:
|
||||
key = f"{node_id}.#{'.'.join(condition.variable_selector)}#"
|
||||
var_mapping[key] = condition.variable_selector
|
||||
|
||||
@@ -3,7 +3,9 @@ from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState
|
||||
|
||||
|
||||
class ErrorHandleMode(StrEnum):
|
||||
@@ -17,6 +19,7 @@ class IterationNodeData(BaseIterationNodeData):
|
||||
Iteration Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.ITERATION
|
||||
parent_loop_id: str | None = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
@@ -31,7 +34,7 @@ class IterationStartNodeData(BaseNodeData):
|
||||
Iteration Start Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
type: NodeType = NodeType.ITERATION_START
|
||||
|
||||
|
||||
class IterationState(BaseIterationState):
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.enums import (
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
@@ -460,21 +461,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: IterationNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = IterationNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping: dict[str, Sequence[str]] = {
|
||||
f"{node_id}.input_selector": typed_node_data.iterator_selector,
|
||||
f"{node_id}.input_selector": node_data.iterator_selector,
|
||||
}
|
||||
iteration_node_ids = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
if node_data.get("iteration_id") == node_id:
|
||||
node_config_data = node.get("data", {})
|
||||
if node_config_data.get("iteration_id") == node_id:
|
||||
in_iteration_node_id = node.get("id")
|
||||
if in_iteration_node_id:
|
||||
iteration_node_ids.add(in_iteration_node_id)
|
||||
@@ -488,16 +486,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.nodes.node_mapping import get_node_type_classes_mapping
|
||||
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
|
||||
node_type = typed_sub_node_config["data"].type
|
||||
node_mapping = get_node_type_classes_mapping()
|
||||
if node_type not in node_mapping:
|
||||
continue
|
||||
node_version = sub_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
node_version = str(typed_sub_node_config["data"].version)
|
||||
node_cls = node_mapping[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
graph_config=graph_config, config=typed_sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Literal, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
@@ -155,7 +156,7 @@ class KnowledgeIndexNodeData(BaseNodeData):
|
||||
Knowledge index Node Data.
|
||||
"""
|
||||
|
||||
type: str = "knowledge-index"
|
||||
type: NodeType = NodeType.KNOWLEDGE_INDEX
|
||||
chunk_structure: str
|
||||
index_chunk_variable_selector: list[str]
|
||||
indexing_technique: str | None = None
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
@@ -30,7 +31,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
index_processor: IndexProcessorProtocol,
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
@@ -113,7 +114,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
Knowledge retrieval Node Data.
|
||||
"""
|
||||
|
||||
type: str = "knowledge-retrieval"
|
||||
type: NodeType = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
query_variable_selector: list[str] | None | str = None
|
||||
query_attachment_selector: list[str] | None | str = None
|
||||
dataset_ids: list[str]
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
@@ -49,7 +50,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
rag_retrieval: RAGRetrievalProtocol,
|
||||
@@ -301,15 +302,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: KnowledgeRetrievalNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
if typed_node_data.query_variable_selector:
|
||||
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
|
||||
if typed_node_data.query_attachment_selector:
|
||||
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
|
||||
if node_data.query_variable_selector:
|
||||
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
|
||||
if node_data.query_attachment_selector:
|
||||
variable_mapping[node_id + ".queryAttachment"] = node_data.query_attachment_selector
|
||||
return variable_mapping
|
||||
|
||||
@@ -3,7 +3,8 @@ from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class FilterOperator(StrEnum):
|
||||
@@ -62,6 +63,7 @@ class ExtractConfig(BaseModel):
|
||||
|
||||
|
||||
class ListOperatorNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.LIST_OPERATOR
|
||||
variable: Sequence[str] = Field(default_factory=list)
|
||||
filter_by: FilterBy
|
||||
order_by: OrderByConfig
|
||||
|
||||
@@ -4,8 +4,9 @@ from typing import Any, Literal
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
@@ -59,6 +60,7 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.LLM
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
|
||||
|
||||
@@ -21,6 +21,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.tools.signature import sign_upload_file
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
@@ -121,7 +122,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
@@ -954,14 +955,11 @@ class LLMNode(Node[LLMNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: LLMNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = LLMNodeData.model_validate(node_data)
|
||||
|
||||
prompt_template = typed_node_data.prompt_template
|
||||
prompt_template = node_data.prompt_template
|
||||
variable_selectors = []
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
@@ -979,7 +977,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
memory = typed_node_data.memory
|
||||
memory = node_data.memory
|
||||
if memory and memory.query_prompt_template:
|
||||
query_variable_selectors = VariableTemplateParser(
|
||||
template=memory.query_prompt_template
|
||||
@@ -987,16 +985,16 @@ class LLMNode(Node[LLMNodeData]):
|
||||
for variable_selector in query_variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
if typed_node_data.context.enabled:
|
||||
variable_mapping["#context#"] = typed_node_data.context.variable_selector
|
||||
if node_data.context.enabled:
|
||||
variable_mapping["#context#"] = node_data.context.variable_selector
|
||||
|
||||
if typed_node_data.vision.enabled:
|
||||
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
|
||||
if node_data.vision.enabled:
|
||||
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
|
||||
|
||||
if typed_node_data.memory:
|
||||
if node_data.memory:
|
||||
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
|
||||
|
||||
if typed_node_data.prompt_config:
|
||||
if node_data.prompt_config:
|
||||
enable_jinja = False
|
||||
|
||||
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||
@@ -1009,7 +1007,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
break
|
||||
|
||||
if enable_jinja:
|
||||
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
|
||||
|
||||
@@ -3,7 +3,9 @@ from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||
|
||||
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
|
||||
from dify_graph.utils.condition.entities import Condition
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
@@ -39,6 +41,7 @@ class LoopVariableData(BaseModel):
|
||||
|
||||
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
type: NodeType = NodeType.LOOP
|
||||
loop_count: int # Maximum number of loops
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
@@ -58,7 +61,7 @@ class LoopStartNodeData(BaseNodeData):
|
||||
Loop Start Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
type: NodeType = NodeType.LOOP_START
|
||||
|
||||
|
||||
class LoopEndNodeData(BaseNodeData):
|
||||
@@ -66,7 +69,7 @@ class LoopEndNodeData(BaseNodeData):
|
||||
Loop End Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
type: NodeType = NodeType.LOOP_END
|
||||
|
||||
|
||||
class LoopState(BaseLoopState):
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.enums import (
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
@@ -298,11 +299,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: LoopNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = LoopNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
|
||||
# Extract loop node IDs statically from graph_config
|
||||
@@ -318,16 +316,18 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.nodes.node_mapping import get_node_type_classes_mapping
|
||||
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
|
||||
node_type = typed_sub_node_config["data"].type
|
||||
node_mapping = get_node_type_classes_mapping()
|
||||
if node_type not in node_mapping:
|
||||
continue
|
||||
node_version = sub_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
node_version = str(typed_sub_node_config["data"].version)
|
||||
node_cls = node_mapping[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
graph_config=graph_config, config=typed_sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
@@ -342,7 +342,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
for loop_variable in typed_node_data.loop_variables or []:
|
||||
for loop_variable in node_data.loop_variables or []:
|
||||
if loop_variable.value_type == "variable":
|
||||
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
|
||||
# add loop variable to variable mapping
|
||||
|
||||
@@ -5,5 +5,24 @@ from dify_graph.nodes.base.node import Node
|
||||
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks dify_graph.nodes
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||
|
||||
def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]:
|
||||
"""Return the live node registry after importing all `dify_graph.nodes` modules."""
|
||||
return Node.get_node_type_classes_mapping()
|
||||
|
||||
|
||||
def resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
node_mapping = get_node_type_classes_mapping().get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
latest_node_class = node_mapping.get(LATEST_VERSION)
|
||||
matched_node_class = node_mapping.get(node_version)
|
||||
node_class = matched_node_class or latest_node_class
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
return node_class
|
||||
|
||||
|
||||
# Snapshot kept for compatibility with older tests; production paths should use the live helpers.
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = get_node_type_classes_mapping()
|
||||
|
||||
@@ -8,7 +8,8 @@ from pydantic import (
|
||||
)
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
@@ -83,6 +84,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
Parameter Extractor Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.PARAMETER_EXTRACTOR
|
||||
model: ModelConfig
|
||||
query: list[str]
|
||||
parameters: list[ParameterConfig]
|
||||
|
||||
@@ -10,6 +10,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
@@ -106,7 +107,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -837,15 +838,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: ParameterExtractorNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
|
||||
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
|
||||
|
||||
if typed_node_data.instruction:
|
||||
selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
|
||||
if node_data.instruction:
|
||||
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
|
||||
for selector in selectors:
|
||||
variable_mapping[selector.variable] = selector.value_selector
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.llm import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
@@ -11,6 +12,7 @@ class ClassConfig(BaseModel):
|
||||
|
||||
|
||||
class QuestionClassifierNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.QUESTION_CLASSIFIER
|
||||
query_variable_selector: list[str]
|
||||
model: ModelConfig
|
||||
classes: list[ClassConfig]
|
||||
|
||||
@@ -7,6 +7,7 @@ from core.model_manager import ModelInstance
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
@@ -62,7 +63,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -251,16 +252,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: QuestionClassifierNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {"query": typed_node_data.query_variable_selector}
|
||||
variable_mapping = {"query": node_data.query_variable_selector}
|
||||
variable_selectors: list[VariableSelector] = []
|
||||
if typed_node_data.instruction:
|
||||
variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
|
||||
if node_data.instruction:
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)
|
||||
|
||||
@@ -2,7 +2,8 @@ from collections.abc import Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.variables.input_entities import VariableEntity
|
||||
|
||||
|
||||
@@ -11,4 +12,5 @@ class StartNodeData(BaseNodeData):
|
||||
Start Node Data
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.START
|
||||
variables: Sequence[VariableEntity] = Field(default_factory=list)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
@@ -7,5 +8,6 @@ class TemplateTransformNodeData(BaseNodeData):
|
||||
Template Transform Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.TEMPLATE_TRANSFORM
|
||||
variables: list[VariableSelector]
|
||||
template: str
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
@@ -25,7 +26,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -86,12 +87,9 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
|
||||
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = TemplateTransformNodeData.model_validate(node_data)
|
||||
|
||||
return {
|
||||
node_id + "." + variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in typed_node_data.variables
|
||||
for variable_selector in node_data.variables
|
||||
}
|
||||
|
||||
@@ -4,7 +4,8 @@ from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
@@ -32,6 +33,8 @@ class ToolEntity(BaseModel):
|
||||
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
type: NodeType = NodeType.TOOL
|
||||
|
||||
class ToolInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
|
||||
@@ -7,6 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
@@ -46,7 +47,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -64,6 +65,10 @@ class ToolNode(Node[ToolNodeData]):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def populate_start_event(self, event) -> None:
|
||||
event.provider_id = self.node_data.provider_id
|
||||
event.provider_type = self.node_data.provider_type
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Run the tool node
|
||||
@@ -484,7 +489,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: ToolNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -493,9 +498,8 @@ class ToolNode(Node[ToolNodeData]):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = ToolNodeData.model_validate(node_data)
|
||||
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
typed_node_data = node_data
|
||||
result = {}
|
||||
for parameter_name in typed_node_data.tool_parameters:
|
||||
input = typed_node_data.tool_parameters[parameter_name]
|
||||
|
||||
@@ -4,13 +4,16 @@ from typing import Any, Literal, Union
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.trigger.entities.entities import EventParameter
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.trigger_plugin.exc import TriggerEventParameterError
|
||||
|
||||
|
||||
class TriggerEventNodeData(BaseNodeData):
|
||||
"""Plugin trigger node data"""
|
||||
|
||||
type: NodeType = NodeType.TRIGGER_PLUGIN
|
||||
|
||||
class TriggerEventInput(BaseModel):
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
@@ -38,8 +41,6 @@ class TriggerEventNodeData(BaseNodeData):
|
||||
raise ValueError("value must be a string, int, float, bool or dict")
|
||||
return type
|
||||
|
||||
title: str
|
||||
desc: str | None = None
|
||||
plugin_id: str = Field(..., description="Plugin ID")
|
||||
provider_id: str = Field(..., description="Provider ID")
|
||||
event_name: str = Field(..., description="Event name")
|
||||
|
||||
@@ -32,6 +32,9 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def populate_start_event(self, event) -> None:
|
||||
event.provider_id = self.node_data.provider_id
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the plugin trigger node.
|
||||
|
||||
@@ -2,7 +2,8 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class TriggerScheduleNodeData(BaseNodeData):
|
||||
@@ -10,6 +11,7 @@ class TriggerScheduleNodeData(BaseNodeData):
|
||||
Trigger Schedule Node Data
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.TRIGGER_SCHEDULE
|
||||
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
|
||||
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
|
||||
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dify_graph.nodes.base.exc import BaseNodeError
|
||||
from dify_graph.entities.exc import BaseNodeError
|
||||
|
||||
|
||||
class ScheduleNodeError(BaseNodeError):
|
||||
|
||||
@@ -1,10 +1,41 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
_WEBHOOK_HEADER_ALLOWED_TYPES = frozenset(
|
||||
{
|
||||
SegmentType.STRING,
|
||||
}
|
||||
)
|
||||
|
||||
_WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES = frozenset(
|
||||
{
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.BOOLEAN,
|
||||
}
|
||||
)
|
||||
|
||||
_WEBHOOK_PARAMETER_ALLOWED_TYPES = _WEBHOOK_HEADER_ALLOWED_TYPES | _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES
|
||||
|
||||
_WEBHOOK_BODY_ALLOWED_TYPES = frozenset(
|
||||
{
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.FILE,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class Method(StrEnum):
|
||||
@@ -25,29 +56,34 @@ class ContentType(StrEnum):
|
||||
|
||||
|
||||
class WebhookParameter(BaseModel):
|
||||
"""Parameter definition for headers, query params, or body."""
|
||||
"""Parameter definition for headers or query params."""
|
||||
|
||||
name: str
|
||||
type: SegmentType = SegmentType.STRING
|
||||
required: bool = False
|
||||
|
||||
@field_validator("type", mode="after")
|
||||
@classmethod
|
||||
def validate_type(cls, v: SegmentType) -> SegmentType:
|
||||
if v not in _WEBHOOK_PARAMETER_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook parameter type: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class WebhookBodyParameter(BaseModel):
|
||||
"""Body parameter with type information."""
|
||||
|
||||
name: str
|
||||
type: Literal[
|
||||
"string",
|
||||
"number",
|
||||
"boolean",
|
||||
"object",
|
||||
"array[string]",
|
||||
"array[number]",
|
||||
"array[boolean]",
|
||||
"array[object]",
|
||||
"file",
|
||||
] = "string"
|
||||
type: SegmentType = SegmentType.STRING
|
||||
required: bool = False
|
||||
|
||||
@field_validator("type", mode="after")
|
||||
@classmethod
|
||||
def validate_type(cls, v: SegmentType) -> SegmentType:
|
||||
if v not in _WEBHOOK_BODY_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook body parameter type: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class WebhookData(BaseNodeData):
|
||||
"""
|
||||
@@ -57,6 +93,7 @@ class WebhookData(BaseNodeData):
|
||||
class SyncMode(StrEnum):
|
||||
SYNC = "async" # only support
|
||||
|
||||
type: NodeType = NodeType.TRIGGER_WEBHOOK
|
||||
method: Method = Method.GET
|
||||
content_type: ContentType = Field(default=ContentType.JSON)
|
||||
headers: Sequence[WebhookParameter] = Field(default_factory=list)
|
||||
@@ -71,6 +108,22 @@ class WebhookData(BaseNodeData):
|
||||
return v.lower()
|
||||
return v
|
||||
|
||||
@field_validator("headers", mode="after")
|
||||
@classmethod
|
||||
def validate_header_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]:
|
||||
for param in v:
|
||||
if param.type not in _WEBHOOK_HEADER_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook header parameter type: {param.type}")
|
||||
return v
|
||||
|
||||
@field_validator("params", mode="after")
|
||||
@classmethod
|
||||
def validate_query_parameter_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]:
|
||||
for param in v:
|
||||
if param.type not in _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook query parameter type: {param.type}")
|
||||
return v
|
||||
|
||||
status_code: int = 200 # Expected status code for response
|
||||
response_body: str = "" # Template for response body
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dify_graph.nodes.base.exc import BaseNodeError
|
||||
from dify_graph.entities.exc import BaseNodeError
|
||||
|
||||
|
||||
class WebhookNodeError(BaseNodeError):
|
||||
|
||||
@@ -152,7 +152,7 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
outputs[param_name] = raw_data
|
||||
continue
|
||||
|
||||
if param_type == "file":
|
||||
if param_type == SegmentType.FILE:
|
||||
# Get File object (already processed by webhook controller)
|
||||
files = webhook_data.get("files", {})
|
||||
if files and isinstance(files, dict):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
|
||||
@@ -28,6 +29,7 @@ class VariableAggregatorNodeData(BaseNodeData):
|
||||
Variable Aggregator Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.VARIABLE_AGGREGATOR
|
||||
output_type: str
|
||||
variables: list[list[str]]
|
||||
advanced_settings: AdvancedSettings | None = None
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
@@ -22,7 +23,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
@@ -52,21 +53,18 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: VariableAssignerData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = VariableAssignerData.model_validate(node_data)
|
||||
|
||||
mapping = {}
|
||||
assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
|
||||
assigned_variable_node_id = node_data.assigned_variable_selector[0]
|
||||
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
selector_key = ".".join(typed_node_data.assigned_variable_selector)
|
||||
selector_key = ".".join(node_data.assigned_variable_selector)
|
||||
key = f"{node_id}.#{selector_key}#"
|
||||
mapping[key] = typed_node_data.assigned_variable_selector
|
||||
mapping[key] = node_data.assigned_variable_selector
|
||||
|
||||
selector_key = ".".join(typed_node_data.input_variable_selector)
|
||||
selector_key = ".".join(node_data.input_variable_selector)
|
||||
key = f"{node_id}.#{selector_key}#"
|
||||
mapping[key] = typed_node_data.input_variable_selector
|
||||
mapping[key] = node_data.input_variable_selector
|
||||
return mapping
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class WriteMode(StrEnum):
|
||||
@@ -11,6 +12,7 @@ class WriteMode(StrEnum):
|
||||
|
||||
|
||||
class VariableAssignerData(BaseNodeData):
|
||||
type: NodeType = NodeType.VARIABLE_ASSIGNER
|
||||
assigned_variable_selector: Sequence[str]
|
||||
write_mode: WriteMode
|
||||
input_variable_selector: Sequence[str]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user