Compare commits

...

45 Commits

Author SHA1 Message Date
WH-2099
62b40b888c test: avoid leaking tool parameter decrypt mock 2026-03-14 03:56:47 +08:00
WH-2099
5dda0eff0e refactor(workflow): remove unused workflow package re-export 2026-03-14 02:44:08 +08:00
WH-2099
ff824917d5 fix(api): validate plugin upload size using file content 2026-03-14 02:21:03 +08:00
WH-2099
5e07cb4b0f refactor(workflow): move agent node back to core workflow 2026-03-14 02:07:22 +08:00
非法操作
573b4e41cb fix: correctly detect required columns in archived workflow run restore (#33403)
Some checks are pending
autofix.ci / autofix (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-13 23:24:40 +08:00
yevanmore
194c205ed3 fix(api): allow punctuation in uploaded filenames (#33364)
Co-authored-by: Chen Yefan <cyefan2@gmail.com>
2026-03-13 21:33:09 +08:00
Coding On Star
7e1dc3c122 refactor(web): split share text-generation and add high-coverage tests (#33408)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-13 19:10:24 +08:00
Stephen Zhou
4203647c32 chore: use vite plus (#33407) 2026-03-13 18:18:44 +08:00
dependabot[bot]
20e91990bf chore(deps): bump orjson from 3.11.4 to 3.11.6 in /api (#33380)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-13 17:33:33 +09:00
Asuka Minato
f38e8cca52 test: [Refactor/Chore] use Testcontainers to do sql test #32454 (#32460) 2026-03-13 17:32:39 +09:00
Coding On Star
00eda73ad1 test: enforce app/components coverage gates in web tests (#33395)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-03-13 16:31:05 +08:00
Ethan T.
8b40a89add fix: with_debug_recipient() silently drops debug emails when user_id is None or empty string (#33373)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-03-13 15:35:02 +08:00
Joel
97776eabff chore: add tracking info of in site message (#33394)
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-03-13 15:29:24 +08:00
yyh
fe561ef3d0 feat(workflow): add edge context menu with delete support (#33391) 2026-03-13 15:11:24 +08:00
lif
1104d35bbb chore: remove unused WEAVIATE_GRPC_ENABLED config option (#33378)
Signed-off-by: majiayu000 <1835304752@qq.com>
2026-03-13 15:06:50 +08:00
Stephen Zhou
724eaee77e chore: add dev proxy server, update deps (#33371) 2026-03-13 12:52:19 +08:00
Copilot
4717168fe2 chore(web): disable i18next support notice (#33309)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2026-03-12 18:18:10 +08:00
wangxiaolei
7fd3bd81ab fix: test_get_credentials_and_config_selects_plugin_id_and_key_api_ke… (#33361) 2026-03-12 19:09:46 +09:00
NFish
0dcfac5b84 fix: The HTTP Request node panel supports line break and overflow handling for variable values (#33359)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-12 17:57:08 +08:00
Stephen Zhou
b66097b5f3 chore: use tsconfigPaths for vinext (#33363) 2026-03-12 17:56:22 +08:00
Coding On Star
ceaa399351 test: refactor mock implementation in markdown component tests (#33350)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-03-12 15:50:07 +08:00
Rajat Agarwal
dc50e4c4f2 test: added test cases for core.workflow module (#33126) 2026-03-12 15:35:25 +08:00
mahammadasim
157208ab1e test: added test for services of ops, summary, vector, website and ji… (#32893)
Co-authored-by: akashseth-ifp <akash.seth@infocusp.com>
2026-03-12 15:34:20 +08:00
mahammadasim
3dabdc8282 test: added tests for backend core.ops module (#32639)
Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
2026-03-12 15:33:15 +08:00
Saumya Talwani
ed5511ce28 test: improve coverage for some files (#33218) 2026-03-12 15:09:10 +08:00
Saumya Talwani
68982f910e test: improve coverage parameters for some files in base (#33207) 2026-03-12 14:57:31 +08:00
yyh
c43307dae1 refactor(switch): Base UI migration with loading/skeleton variants (#33345)
Signed-off-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-12 14:40:43 +08:00
Stephen Zhou
b44b37518a chore: update vinext (#33347) 2026-03-12 13:18:11 +08:00
Rajat Agarwal
b170eabaf3 test: Unit test cases for core.tools module (#32447)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
Co-authored-by: akashseth-ifp <akash.seth@infocusp.com>
Co-authored-by: mahammadasim <135003320+mahammadasim@users.noreply.github.com>
2026-03-12 11:48:13 +08:00
mahammadasim
e99628b76f test: added test for core token buffer memory and model runtime (#32512)
Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
2026-03-12 11:46:46 +08:00
mahammadasim
60fe5e7f00 test: added for core logging and core mcp (#32478)
Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
2026-03-12 11:44:56 +08:00
mahammadasim
245f6b824d test: add test for core extension, external_data_tool and llm generator (#32468) 2026-03-12 11:44:38 +08:00
Dev Sharma
7d2054d4f4 test: add UTs for api/services recommend_app, tools, workflow (#32645)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-12 11:37:03 +08:00
Rajat Agarwal
07e19c0748 test: unit test cases for core.variables, core.plugin, core.prompt module (#32637) 2026-03-12 11:29:02 +08:00
Dev Sharma
135b3a15a6 test: add UTs for api/ services.plugin (#32588)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-12 11:27:29 +08:00
Rajat Agarwal
0045e387f5 test: unit test cases for core.app.apps module (#32482) 2026-03-12 11:23:25 +08:00
NFish
44713a5c0f fix: allow line breaks when a field value overflows due to excessive length (#33339) 2026-03-12 11:15:29 +08:00
Rajat Agarwal
d5724aebde test: unit test cases core.agent module (#32474) 2026-03-12 11:10:15 +08:00
Rajat Agarwal
c59685748c test: unit test cases for core.callback, core.base, core.entities module (#32471) 2026-03-12 11:09:08 +08:00
Dev Sharma
36c1f4d506 test: improve unit tests for controllers.inner_api (#32203) 2026-03-12 11:07:56 +08:00
Asuka Minato
31eba65fe0 ci: Revert "chore(deps): bump the python-packages group across 1 directory with 13 updates" (#33331)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2026-03-12 10:48:11 +09:00
dependabot[bot]
72496a5847 chore(deps): bump the python-packages group across 1 directory with 13 updates (#33319)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-03-12 01:22:39 +09:00
dependabot[bot]
8b16030d6b chore(deps-dev): bump the dev group across 1 directory with 18 updates (#33322)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-12 01:02:28 +09:00
盐粒 Yanli
989db0e584 refactor: Unify NodeConfigDict.data and BaseNodeData (#32780)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-11 23:43:58 +08:00
dependabot[bot]
a0f0c97133 chore(deps): bump opentelemetry-propagator-b3 from 1.28.0 to 1.40.0 in /api in the opentelemetry group (#33308)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-03-12 00:37:48 +09:00
577 changed files with 71994 additions and 5665 deletions

View File

@@ -1,33 +1,13 @@
name: Setup Web Environment
description: Setup pnpm, Node.js, and install web dependencies.
inputs:
node-version:
description: Node.js version to use
required: false
default: "22"
install-dependencies:
description: Whether to install web dependencies after setting up Node.js
required: false
default: "true"
runs:
using: composite
steps:
- name: Install pnpm
uses: pnpm/action-setup@41ff72655975bd51cab0327fa583b6e92b6d3061 # v4.2.0
- name: Setup Vite+
uses: voidzero-dev/setup-vp@b5d848f5a62488f3d3d920f8aa6ac318a60c5f07 # v1
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
with:
node-version: ${{ inputs.node-version }}
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
if: ${{ inputs.install-dependencies == 'true' }}
shell: bash
run: pnpm --dir web install --frozen-lockfile
node-version-file: "./web/.nvmrc"
cache: true
run-install: |
- cwd: ./web
args: ['--frozen-lockfile']

View File

@@ -102,13 +102,11 @@ jobs:
- name: Setup web environment
if: steps.web-changes.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
with:
node-version: "24"
- name: ESLint autofix
if: steps.web-changes.outputs.any_changed == 'true'
run: |
cd web
pnpm eslint --concurrency=2 --prune-suppressions --quiet || true
vp exec eslint --concurrency=2 --prune-suppressions --quiet || true
- uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3

View File

@@ -62,6 +62,9 @@ jobs:
needs: check-changes
if: needs.check-changes.outputs.web-changed == 'true'
uses: ./.github/workflows/web-tests.yml
with:
base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }}
head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
style-check:
name: Style Check

View File

@@ -88,7 +88,7 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: |
pnpm run lint:ci
vp run lint:ci
# pnpm run lint:report
# continue-on-error: true
@@ -102,17 +102,17 @@ jobs:
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run lint:tss
run: vp run lint:tss
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run type-check
run: vp run type-check
- name: Web dead code check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run knip
run: vp run knip
superlinter:
name: SuperLinter

View File

@@ -50,8 +50,6 @@ jobs:
- name: Setup web environment
uses: ./.github/actions/setup-web
with:
install-dependencies: "false"
- name: Detect changed files and generate diff
id: detect_changes

View File

@@ -2,6 +2,13 @@ name: Web Tests
on:
workflow_call:
inputs:
base_sha:
required: false
type: string
head_sha:
required: false
type: string
permissions:
contents: read
@@ -14,6 +21,8 @@ jobs:
test:
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
runs-on: ubuntu-latest
env:
VITEST_COVERAGE_SCOPE: app-components
strategy:
fail-fast: false
matrix:
@@ -34,7 +43,7 @@ jobs:
uses: ./.github/actions/setup-web
- name: Run tests
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
run: vp test run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
- name: Upload blob report
if: ${{ !cancelled() }}
@@ -50,6 +59,8 @@ jobs:
if: ${{ !cancelled() }}
needs: [test]
runs-on: ubuntu-latest
env:
VITEST_COVERAGE_SCOPE: app-components
defaults:
run:
shell: bash
@@ -59,6 +70,7 @@ jobs:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup web environment
@@ -72,7 +84,13 @@ jobs:
merge-multiple: true
- name: Merge reports
run: pnpm vitest --merge-reports --coverage --silent=passed-only
run: vp test --merge-reports --reporter=json --reporter=agent --coverage
- name: Check app/components diff coverage
env:
BASE_SHA: ${{ inputs.base_sha }}
HEAD_SHA: ${{ inputs.head_sha }}
run: node ./scripts/check-components-diff-coverage.mjs
- name: Coverage Summary
if: always()
@@ -429,4 +447,4 @@ jobs:
- name: Web build check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run build
run: vp run build

View File

@@ -188,7 +188,6 @@ VECTOR_INDEX_NAME_PREFIX=Vector_index
# Weaviate configuration
WEAVIATE_ENDPOINT=http://localhost:8080
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
WEAVIATE_TOKENIZATION=word

View File

@@ -43,7 +43,6 @@ forbidden_modules =
extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
@@ -90,9 +89,6 @@ forbidden_modules =
core.trigger
core.variables
ignore_imports =
dify_graph.nodes.agent.agent_node -> core.model_manager
dify_graph.nodes.agent.agent_node -> core.provider_manager
dify_graph.nodes.agent.agent_node -> core.tools.tool_manager
dify_graph.nodes.llm.llm_utils -> core.model_manager
dify_graph.nodes.llm.protocols -> core.model_manager
dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
@@ -100,8 +96,6 @@ ignore_imports =
dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
dify_graph.nodes.agent.agent_node -> core.agent.entities
dify_graph.nodes.agent.agent_node -> core.agent.plugin_entities
dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
@@ -110,12 +104,10 @@ ignore_imports =
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.agent.agent_node -> models.model
dify_graph.nodes.llm.node -> core.helper.code_executor
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
dify_graph.nodes.llm.node -> core.model_manager
dify_graph.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util
@@ -126,15 +118,11 @@ ignore_imports =
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
dify_graph.nodes.llm.node -> models.dataset
dify_graph.nodes.agent.agent_node -> core.tools.utils.message_transformer
dify_graph.nodes.llm.file_saver -> core.tools.signature
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
dify_graph.nodes.tool.tool_node -> core.tools.errors
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.agent.agent_node -> models
dify_graph.nodes.llm.node -> models.model
dify_graph.nodes.agent.agent_node -> services
dify_graph.nodes.tool.tool_node -> services
dify_graph.model_runtime.model_providers.__base.ai_model -> configs
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis

View File

@@ -17,11 +17,6 @@ class WeaviateConfig(BaseSettings):
default=None,
)
WEAVIATE_GRPC_ENABLED: bool = Field(
description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)",
default=True,
)
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
default=None,

View File

@@ -5,6 +5,7 @@ from typing import Any, Literal
from flask import request, send_file
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden
from configs import dify_config
@@ -169,6 +170,20 @@ register_enum_models(
)
def _read_upload_content(file: FileStorage, max_size: int) -> bytes:
"""
Read the uploaded file and validate its actual size before delegating to the plugin service.
FileStorage.content_length is not reliable for multipart test uploads and may be zero even when
content exists, so the controllers validate against the loaded bytes instead.
"""
content = file.read()
if len(content) > max_size:
raise ValueError("File size exceeds the maximum allowed size")
return content
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@setup_required
@@ -284,12 +299,7 @@ class PluginUploadFromPkgApi(Resource):
_, tenant_id = current_account_with_tenant()
file = request.files["pkg"]
# check file size
if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")
content = file.read()
content = _read_upload_content(file, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
try:
response = PluginService.upload_pkg(tenant_id, content)
except PluginDaemonClientSideError as e:
@@ -328,12 +338,7 @@ class PluginUploadFromBundleApi(Resource):
_, tenant_id = current_account_with_tenant()
file = request.files["bundle"]
# check file size
if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")
content = file.read()
content = _read_upload_content(file, dify_config.PLUGIN_MAX_BUNDLE_SIZE)
try:
response = PluginService.upload_bundle(tenant_id, content)
except PluginDaemonClientSideError as e:

View File

@@ -114,6 +114,7 @@ def get_user_tenant(view_func: Callable[P, R]):
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
try:
data = request.get_json()

View File

@@ -6,6 +6,7 @@ from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
from core.agent.errors import AgentMaxIterationError
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
@@ -22,7 +23,6 @@ from dify_graph.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from dify_graph.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)

9
api/core/agent/errors.py Normal file
View File

@@ -0,0 +1,9 @@
class AgentMaxIterationError(Exception):
"""Raised when an agent runner exceeds the configured max iteration count."""
def __init__(self, max_iteration: int):
self.max_iteration = max_iteration
super().__init__(
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
f"The agent was unable to complete the task within the allowed number of iterations."
)

View File

@@ -5,6 +5,7 @@ from copy import deepcopy
from typing import Any, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.errors import AgentMaxIterationError
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
@@ -25,7 +26,6 @@ from dify_graph.model_runtime.entities import (
UserPromptMessage,
)
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from dify_graph.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)

View File

@@ -114,7 +114,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):

View File

@@ -113,7 +113,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:

View File

@@ -113,7 +113,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:

View File

@@ -3,7 +3,10 @@ import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
from pydantic import ValidationError
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.entities.queue_entities import (
AppQueueEvent,
@@ -30,8 +33,10 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.node_resolution import resolve_workflow_node_class
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.graph import Graph
from dify_graph.graph_engine.layers.base import GraphEngineLayer
@@ -62,8 +67,6 @@ from dify_graph.graph_events import (
NodeRunSucceededEvent,
)
from dify_graph.graph_events.graph import GraphRunAbortedEvent
from dify_graph.nodes import NodeType
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
@@ -303,10 +306,12 @@ class WorkflowBasedAppRunner:
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
# Get node class
node_type = NodeType(target_node_config.get("data", {}).get("type"))
node_version = target_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
node_type = target_node_config["data"].type
node_version = str(target_node_config["data"].version)
node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version)
# Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = graph_runtime_state.variable_pool
@@ -334,6 +339,18 @@ class WorkflowBasedAppRunner:
return graph, variable_pool
@staticmethod
def _build_agent_strategy_info(event: NodeRunStartedEvent) -> AgentStrategyInfo | None:
raw_agent_strategy = event.extras.get("agent_strategy")
if raw_agent_strategy is None:
return None
try:
return AgentStrategyInfo.model_validate(raw_agent_strategy)
except ValidationError:
logger.warning("Invalid agent strategy payload for node %s", event.node_id, exc_info=True)
return None
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event
@@ -419,7 +436,7 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
agent_strategy=event.agent_strategy,
agent_strategy=self._build_agent_strategy_info(event),
provider_type=event.provider_type,
provider_id=event.provider_id,
)

View File

@@ -0,0 +1,3 @@
from .agent_strategy import AgentStrategyInfo
__all__ = ["AgentStrategyInfo"]

View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel, ConfigDict
class AgentStrategyInfo(BaseModel):
name: str
icon: str | None = None
model_config = ConfigDict(extra="forbid")

View File

@@ -5,8 +5,8 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities import AgentNodeStrategyInit
from dify_graph.entities.pause_reason import PauseReason
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
@@ -314,7 +314,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
in_iteration_id: str | None = None
in_loop_id: str | None = None
start_at: datetime
agent_strategy: AgentNodeStrategyInit | None = None
agent_strategy: AgentStrategyInfo | None = None
# FIXME(-LAN-): only for ToolNode, need to refactor
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType

View File

@@ -4,8 +4,8 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities import AgentNodeStrategyInit
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@@ -349,7 +349,7 @@ class NodeStartStreamResponse(StreamResponse):
extras: dict[str, object] = Field(default_factory=dict)
iteration_id: str | None = None
loop_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
agent_strategy: AgentStrategyInfo | None = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str

View File

@@ -193,7 +193,8 @@ class LLMGenerator:
error_step = "generate rule config"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
error = str(e)
error_step = "generate rule config"
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -279,7 +280,8 @@ class LLMGenerator:
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
error = str(e)
error_step = "handle unexpected exception"
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""

View File

@@ -191,7 +191,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
except ValueError:
raise
except Exception:
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
raise ValueError(f"The tool parameter value {repr(value)} is not in correct type of {as_normal_type(typ)}.")
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):

View File

@@ -113,17 +113,26 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self.get_credentials_schema_by_type(CredentialType.API_KEY)
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
def get_credentials_schema_by_type(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:param credential_type: the type of the credential
:return: the credentials schema of the provider
:param credential_type: the type of the credential, as CredentialType or str; str values
are normalized via CredentialType.of and may raise ValueError for invalid values.
:return: list[ProviderConfig] for CredentialType.OAUTH2 or CredentialType.API_KEY, an
empty list for CredentialType.UNAUTHORIZED or missing schemas.
Reads from self.entity.oauth_schema and self.entity.credentials_schema.
Raises ValueError for invalid credential types.
"""
if credential_type == CredentialType.OAUTH2.value:
if isinstance(credential_type, str):
credential_type = CredentialType.of(credential_type)
if credential_type == CredentialType.OAUTH2:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
if credential_type == CredentialType.UNAUTHORIZED:
return []
raise ValueError(f"Invalid credential type: {credential_type}")
def get_oauth_client_schema(self) -> list[ProviderConfig]:

View File

@@ -137,6 +137,7 @@ class ToolFileManager:
session.add(tool_file)
session.commit()
session.refresh(tool_file)
return tool_file

View File

@@ -19,6 +19,7 @@ from core.trigger.debug.events import (
build_plugin_pool_key,
build_webhook_pool_key,
)
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType
from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData
from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig
@@ -41,10 +42,10 @@ class TriggerDebugEventPoller(ABC):
app_id: str
user_id: str
tenant_id: str
node_config: Mapping[str, Any]
node_config: NodeConfigDict
node_id: str
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: Mapping[str, Any], node_id: str):
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: NodeConfigDict, node_id: str):
self.tenant_id = tenant_id
self.user_id = user_id
self.app_id = app_id
@@ -60,7 +61,7 @@ class PluginTriggerDebugEventPoller(TriggerDebugEventPoller):
def poll(self) -> TriggerDebugEvent | None:
from services.trigger.trigger_service import TriggerService
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config.get("data", {}))
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config["data"], from_attributes=True)
provider_id = TriggerProviderID(plugin_trigger_data.provider_id)
pool_key: str = build_plugin_pool_key(
name=plugin_trigger_data.event_name,

View File

@@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast, final
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -22,7 +22,15 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.summary_index.summary_index import SummaryIndex
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.entities.graph_config import NodeConfigDict
from core.workflow.node_resolution import resolve_workflow_node_class
from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
from core.workflow.nodes.agent.plugin_strategy_adapter import (
PluginAgentStrategyPresentationProvider,
PluginAgentStrategyResolver,
)
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import NodeType, SystemVariableKey
from dify_graph.file.file_manager import file_manager
@@ -31,26 +39,18 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
from dify_graph.nodes.code.code_node import WorkflowCodeExecutor
from dify_graph.nodes.code.entities import CodeLanguage
from dify_graph.nodes.code.limits import CodeNodeLimits
from dify_graph.nodes.datasource import DatasourceNode
from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config
from dify_graph.nodes.human_input.human_input_node import HumanInputNode
from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from dify_graph.nodes.llm.entities import ModelConfig
from dify_graph.nodes.document_extractor import UnstructuredApiConfig
from dify_graph.nodes.http_request import build_http_request_config
from dify_graph.nodes.llm.entities import LLMNodeData
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
from dify_graph.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
)
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
from dify_graph.nodes.tool.tool_node import ToolNode
from dify_graph.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
@@ -60,6 +60,9 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
def fetch_memory(
*,
conversation_id: str | None,
@@ -100,10 +103,7 @@ class DefaultWorkflowCodeExecutor:
@final
class DifyNodeFactory(NodeFactory):
"""
Default implementation of NodeFactory that uses the traditional node mapping.
This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
and instantiating the appropriate node class.
Default implementation of NodeFactory that resolves node classes from the live registry.
"""
def __init__(
@@ -146,6 +146,10 @@ class DifyNodeFactory(NodeFactory):
)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
self._agent_strategy_resolver = PluginAgentStrategyResolver()
self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider()
self._agent_runtime_support = AgentRuntimeSupport()
self._agent_message_transformer = AgentMessageTransformer()
@staticmethod
def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext:
@@ -157,178 +161,125 @@ class DifyNodeFactory(NodeFactory):
return DifyRunContext.model_validate(raw_ctx)
@override
def create_node(self, node_config: NodeConfigDict) -> Node:
def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node:
"""
Create a Node instance from node configuration data using the traditional mapping.
:param node_config: node configuration dictionary containing type and other data
:return: initialized Node instance
:raises ValueError: if node type is unknown or configuration is invalid
:raises ValueError: if node_config fails NodeConfigDict/BaseNodeData validation
(including pydantic ValidationError, which subclasses ValueError),
if node type is unknown, or if no implementation exists for the resolved version
"""
# Get node_id from config
node_id = node_config["id"]
# Get node type from config
node_data = node_config["data"]
try:
node_type = NodeType(node_data["type"])
except ValueError:
raise ValueError(f"Unknown node type: {node_data['type']}")
# Get node class
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
latest_node_class = node_mapping.get(LATEST_VERSION)
node_version = str(node_data.get("version", "1"))
matched_node_class = node_mapping.get(node_version)
node_class = matched_node_class or latest_node_class
if not node_class:
raise ValueError(f"No latest version class found for node type: {node_type}")
# Create node instance
if node_type == NodeType.CODE:
return CodeNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
code_executor=self._code_executor,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
max_output_length=self._template_transform_max_output_length,
)
if node_type == NodeType.HTTP_REQUEST:
return HttpRequestNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
http_request_config=self._http_request_config,
http_client=self._http_request_http_client,
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
file_manager=self._http_request_file_manager,
)
if node_type == NodeType.HUMAN_INPUT:
return HumanInputNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
)
if node_type == NodeType.KNOWLEDGE_INDEX:
return KnowledgeIndexNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
index_processor=IndexProcessor(),
summary_index_service=SummaryIndex(),
)
if node_type == NodeType.LLM:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return LLMNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
http_client=self._http_request_http_client,
)
if node_type == NodeType.DATASOURCE:
return DatasourceNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
datasource_manager=DatasourceManager,
)
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return KnowledgeRetrievalNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
rag_retrieval=self._rag_retrieval,
)
if node_type == NodeType.DOCUMENT_EXTRACTOR:
return DocumentExtractorNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
unstructured_api_config=self._document_extractor_unstructured_api_config,
http_client=self._http_request_http_client,
)
if node_type == NodeType.QUESTION_CLASSIFIER:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return QuestionClassifierNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
http_client=self._http_request_http_client,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return ParameterExtractorNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.TOOL:
return ToolNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
tool_file_manager_factory=self._http_request_tool_file_manager_factory(),
)
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
NodeType.CODE: lambda: {
"code_executor": self._code_executor,
"code_limits": self._code_limits,
},
NodeType.TEMPLATE_TRANSFORM: lambda: {
"template_renderer": self._template_renderer,
"max_output_length": self._template_transform_max_output_length,
},
NodeType.HTTP_REQUEST: lambda: {
"http_request_config": self._http_request_config,
"http_client": self._http_request_http_client,
"tool_file_manager_factory": self._http_request_tool_file_manager_factory,
"file_manager": self._http_request_file_manager,
},
NodeType.HUMAN_INPUT: lambda: {
"form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
},
NodeType.KNOWLEDGE_INDEX: lambda: {
"index_processor": IndexProcessor(),
"summary_index_service": SummaryIndex(),
},
NodeType.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=True,
),
NodeType.DATASOURCE: lambda: {
"datasource_manager": DatasourceManager,
},
NodeType.KNOWLEDGE_RETRIEVAL: lambda: {
"rag_retrieval": self._rag_retrieval,
},
NodeType.DOCUMENT_EXTRACTOR: lambda: {
"unstructured_api_config": self._document_extractor_unstructured_api_config,
"http_client": self._http_request_http_client,
},
NodeType.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=True,
),
NodeType.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=False,
),
NodeType.TOOL: lambda: {
"tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
},
NodeType.AGENT: lambda: {
"strategy_resolver": self._agent_strategy_resolver,
"presentation_provider": self._agent_strategy_presentation_provider,
"runtime_support": self._agent_runtime_support,
"message_transformer": self._agent_message_transformer,
},
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
)
def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance:
node_data_model = ModelConfig.model_validate(node_data["model"])
@staticmethod
def _validate_resolved_node_data(node_class: type[Node], node_data: BaseNodeData) -> BaseNodeData:
"""
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
"""
return node_class.validate_node_data(node_data)
@staticmethod
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
return resolve_workflow_node_class(node_type=node_type, node_version=node_version)
def _build_llm_compatible_node_init_kwargs(
self,
*,
node_class: type[Node],
node_data: BaseNodeData,
include_http_client: bool,
) -> dict[str, object]:
validated_node_data = cast(
LLMCompatibleNodeData,
self._validate_resolved_node_data(node_class=node_class, node_data=node_data),
)
model_instance = self._build_model_instance_for_llm_node(validated_node_data)
node_init_kwargs: dict[str, object] = {
"credentials_provider": self._llm_credentials_provider,
"model_factory": self._llm_model_factory,
"model_instance": model_instance,
"memory": self._build_memory_for_llm_node(
node_data=validated_node_data,
model_instance=model_instance,
),
}
if include_http_client:
node_init_kwargs["http_client"] = self._http_request_http_client
return node_init_kwargs
def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance:
node_data_model = node_data.model
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
@@ -364,14 +315,12 @@ class DifyNodeFactory(NodeFactory):
def _build_memory_for_llm_node(
self,
*,
node_data: Mapping[str, Any],
node_data: LLMCompatibleNodeData,
model_instance: ModelInstance,
) -> PromptMessageMemory | None:
raw_memory_config = node_data.get("memory")
if raw_memory_config is None:
if node_data.memory is None:
return None
node_memory = MemoryConfig.model_validate(raw_memory_config)
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
@@ -381,6 +330,6 @@ class DifyNodeFactory(NodeFactory):
return fetch_memory(
conversation_id=conversation_id,
app_id=self._dify_context.app_id,
node_data_memory=node_memory,
node_data_memory=node_data.memory,
model_instance=model_instance,
)

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
from collections.abc import Mapping
from importlib import import_module
from dify_graph.enums import NodeType
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.node_mapping import LATEST_VERSION, get_node_type_classes_mapping
_WORKFLOW_NODE_MODULES = ("core.workflow.nodes.agent",)
_workflow_nodes_registered = False
def ensure_workflow_nodes_registered() -> None:
"""Import workflow-local node modules so they can register with `Node.__init_subclass__`."""
global _workflow_nodes_registered
if _workflow_nodes_registered:
return
for module_name in _WORKFLOW_NODE_MODULES:
import_module(module_name)
_workflow_nodes_registered = True
def get_workflow_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]:
ensure_workflow_nodes_registered()
return get_node_type_classes_mapping()
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
node_mapping = get_workflow_node_type_classes_mapping().get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
latest_node_class = node_mapping.get(LATEST_VERSION)
matched_node_class = node_mapping.get(node_version)
node_class = matched_node_class or latest_node_class
if not node_class:
raise ValueError(f"No latest version class found for node type: {node_type}")
return node_class

View File

@@ -0,0 +1,4 @@
from .agent_node import AgentNode
from .entities import AgentNodeData
__all__ = ["AgentNode", "AgentNodeData"]

View File

@@ -0,0 +1,188 @@
from __future__ import annotations
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType, SystemVariableKey, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from .entities import AgentNodeData
from .exceptions import (
AgentInvocationError,
AgentMessageTransformError,
)
from .message_transformer import AgentMessageTransformer
from .runtime_support import AgentRuntimeSupport
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState
class AgentNode(Node[AgentNodeData]):
node_type = NodeType.AGENT
_strategy_resolver: AgentStrategyResolver
_presentation_provider: AgentStrategyPresentationProvider
_runtime_support: AgentRuntimeSupport
_message_transformer: AgentMessageTransformer
def __init__(
self,
id: str,
config: NodeConfigDict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
strategy_resolver: AgentStrategyResolver,
presentation_provider: AgentStrategyPresentationProvider,
runtime_support: AgentRuntimeSupport,
message_transformer: AgentMessageTransformer,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._strategy_resolver = strategy_resolver
self._presentation_provider = presentation_provider
self._runtime_support = runtime_support
self._message_transformer = message_transformer
@classmethod
def version(cls) -> str:
return "1"
def populate_start_event(self, event) -> None:
dify_ctx = self.require_dify_context()
event.extras["agent_strategy"] = {
"name": self.node_data.agent_strategy_name,
"icon": self._presentation_provider.get_icon(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
),
}
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
dify_ctx = self.require_dify_context()
try:
strategy = self._strategy_resolver.resolve(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self.node_data.agent_strategy_name,
)
except Exception as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error=f"Failed to get agent strategy: {str(e)}",
),
)
return
agent_parameters = strategy.get_parameters()
parameters = self._runtime_support.build_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
strategy=strategy,
tenant_id=dify_ctx.tenant_id,
app_id=dify_ctx.app_id,
invoke_from=dify_ctx.invoke_from,
)
parameters_for_log = self._runtime_support.build_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
strategy=strategy,
tenant_id=dify_ctx.tenant_id,
app_id=dify_ctx.app_id,
invoke_from=dify_ctx.invoke_from,
for_log=True,
)
credentials = self._runtime_support.build_credentials(parameters=parameters)
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = strategy.invoke(
params=parameters,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(error),
)
)
return
try:
yield from self._message_transformer.transform(
messages=message_stream,
tool_info={
"icon": self._presentation_provider.get_icon(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
),
"agent_strategy": self.node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(transform_error),
)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AgentNodeData,
) -> Mapping[str, Sequence[str]]:
_ = graph_config # Explicitly mark as unused
result: dict[str, Any] = {}
typed_node_data = node_data
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
match input.type:
case "mixed" | "constant":
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()}
return result

View File

@@ -5,13 +5,15 @@ from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class AgentNodeData(BaseNodeData):
agent_strategy_provider_name: str # redundancy
type: NodeType = NodeType.AGENT
agent_strategy_provider_name: str
agent_strategy_name: str
agent_strategy_label: str # redundancy
agent_strategy_label: str
memory: MemoryConfig | None = None
# The version of the tool parameter.
# If this value is None, it indicates this is a previous version

View File

@@ -119,14 +119,3 @@ class AgentVariableTypeError(AgentNodeError):
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)
class AgentMaxIterationError(AgentNodeError):
"""Exception raised when the agent exceeds the maximum iteration limit."""
def __init__(self, max_iteration: int):
self.max_iteration = max_iteration
super().__init__(
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
f"The agent was unable to complete the task within the allowed number of iterations."
)

View File

@@ -0,0 +1,292 @@
from __future__ import annotations
from collections.abc import Generator, Mapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.file import File, FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
AgentLogEvent,
NodeEventBase,
NodeRunResult,
StreamChunkEvent,
StreamCompletedEvent,
)
from dify_graph.variables.segments import ArrayFileSegment
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError
class AgentMessageTransformer:
def transform(
self,
*,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json_list: list[dict | list] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage = LLMUsage.empty_usage()
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
if isinstance(message.message.json_object, dict):
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
else:
llm_usage = LLMUsage.empty_usage()
agent_execution_metadata = {}
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise AgentVariableTypeError(
"When 'stream' is True, 'variable_value' must be a string.",
variable_name=variable_name,
expected_type="str",
actual_type=type(variable_value).__name__,
)
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, dict)
if "file" not in message.meta:
raise AgentNodeError("File message is missing 'file' key in meta")
if not isinstance(message.meta["file"], File):
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
message_id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=node_id,
)
for log in agent_logs:
if log.message_id == agent_log.message_id:
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
json_output: list[dict[str, Any] | list[Any]] = []
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.message_id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
if json_list:
json_output.extend(json_list)
else:
json_output.append({"data": []})
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk="",
is_final=True,
)
for var_name in variables:
yield StreamChunkEvent(
selector=[node_id, var_name],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"text": text,
"usage": jsonable_encoder(llm_usage),
"files": ArrayFileSegment(value=files),
"json": json_output,
**variables,
},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
from factories.agent_factory import get_plugin_agent_strategy
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver, ResolvedAgentStrategy
class PluginAgentStrategyResolver(AgentStrategyResolver):
def resolve(
self,
*,
tenant_id: str,
agent_strategy_provider_name: str,
agent_strategy_name: str,
) -> ResolvedAgentStrategy:
return get_plugin_agent_strategy(
tenant_id=tenant_id,
agent_strategy_provider_name=agent_strategy_provider_name,
agent_strategy_name=agent_strategy_name,
)
class PluginAgentStrategyPresentationProvider(AgentStrategyPresentationProvider):
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None:
from core.plugin.impl.plugin import PluginInstaller
manager = PluginInstaller()
try:
plugins = manager.list_plugins(tenant_id)
except Exception:
return None
try:
current_plugin = next(
plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == agent_strategy_provider_name
)
except StopIteration:
return None
return current_plugin.declaration.icon

View File

@@ -0,0 +1,276 @@
from __future__ import annotations
import json
from collections.abc import Sequence
from typing import Any, cast
from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.plugin.entities.request import InvokeCredentials
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager
from dify_graph.enums import SystemVariableKey
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.runtime import VariablePool
from dify_graph.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
from .strategy_protocols import ResolvedAgentStrategy
class AgentRuntimeSupport:
def build_parameters(
self,
*,
agent_parameters: Sequence[AgentStrategyParameter],
variable_pool: VariablePool,
node_data: AgentNodeData,
strategy: ResolvedAgentStrategy,
tenant_id: str,
app_id: str,
invoke_from: Any,
for_log: bool = False,
) -> dict[str, Any]:
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
case "mixed" | "constant":
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
case _:
raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
value = self._filter_mcp_type_tool(strategy, value)
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
if not for_log:
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params}
entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""),
provider_type=provider_type,
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
)
extra = tool.get("extra", {})
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
runtime_variable_pool = variable_pool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id,
app_id,
entity,
invoke_from,
runtime_variable_pool,
)
if tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
extra.get("description", "") or tool_runtime.entity.description.llm
)
for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = (
ToolParameter.ToolParameterForm.FORM
if tool_runtime_params.name in manual_input_params
else tool_runtime_params.form
)
manual_input_value = {}
if tool_runtime.entity.parameters:
manual_input_value = {
key: value for key, value in parameters.items() if key in manual_input_params
}
runtime_parameters = {
**tool_runtime.runtime.runtime_parameters,
**manual_input_value,
}
tool_value.append(
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value,
}
)
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
history_prompt_messages = []
if node_data.memory:
memory = self.fetch_memory(
variable_pool=variable_pool,
app_id=app_id,
model_instance=model_instance,
)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size or None
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
]
value["history_prompt_messages"] = history_prompt_messages
if model_schema:
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
value["entity"] = model_schema.model_dump(mode="json")
else:
value["entity"] = None
result[parameter_name] = value
return result
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
credentials = InvokeCredentials()
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
if not tool.get("credential_id"):
continue
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
except ValidationError:
continue
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
return credentials
def fetch_memory(
self,
*,
variable_pool: VariablePool,
app_id: str,
model_instance: ModelInstance,
) -> TokenBufferMemory | None:
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=value.get("provider", ""),
model_type=ModelType.LLM,
)
model_name = value.get("model", "")
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=model_name,
)
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_instance, model_schema
@staticmethod
def _remove_unsupported_model_features_for_old_version(model_schema: AIModelEntity) -> AIModelEntity:
if model_schema.features:
for feature in model_schema.features[:]:
try:
AgentOldVersionModelFeatures(feature.value)
except ValueError:
model_schema.features.remove(feature)
return model_schema
@staticmethod
def _filter_mcp_type_tool(
strategy: ResolvedAgentStrategy,
tools: list[dict[str, Any]],
) -> list[dict[str, Any]]:
meta_version = strategy.meta_version
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]

View File

@@ -0,0 +1,39 @@
from __future__ import annotations
from collections.abc import Generator, Sequence
from typing import Any, Protocol
from core.agent.plugin_entities import AgentStrategyParameter
from core.plugin.entities.request import InvokeCredentials
from core.tools.entities.tool_entities import ToolInvokeMessage
class ResolvedAgentStrategy(Protocol):
meta_version: str | None
def get_parameters(self) -> Sequence[AgentStrategyParameter]: ...
def invoke(
self,
*,
params: dict[str, Any],
user_id: str,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
credentials: InvokeCredentials | None = None,
) -> Generator[ToolInvokeMessage, None, None]: ...
class AgentStrategyResolver(Protocol):
def resolve(
self,
*,
tenant_id: str,
agent_strategy_provider_name: str,
agent_strategy_name: str,
) -> ResolvedAgentStrategy: ...
class AgentStrategyPresentationProvider(Protocol):
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None: ...

View File

@@ -9,9 +9,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.node_resolution import resolve_workflow_node_class
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.errors import WorkflowNodeRunFailedError
from dify_graph.file.models import File
from dify_graph.graph import Graph
@@ -23,7 +24,6 @@ from dify_graph.graph_engine.protocols.command_channel import CommandChannel
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from dify_graph.nodes import NodeType
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
@@ -212,7 +212,7 @@ class WorkflowEntry:
node_config_data = node_config["data"]
# Get node type
node_type = NodeType(node_config_data["type"])
node_type = node_config_data.type
# init graph init params and runtime state
graph_init_params = GraphInitParams(
@@ -234,8 +234,7 @@ class WorkflowEntry:
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
typed_node_config = cast(dict[str, object], node_config)
node = cast(Any, node_factory).create_node(typed_node_config)
node = node_factory.create_node(node_config)
node_cls = type(node)
try:
@@ -344,7 +343,7 @@ class WorkflowEntry:
if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"]
node_cls = resolve_workflow_node_class(node_type=node_type, node_version="1")
if not node_cls:
raise ValueError(f"Node class not found for node type {node_type}")
@@ -371,10 +370,7 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node_config: NodeConfigDict = {
"id": node_id,
"data": cast(NodeConfigData, node_data),
}
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,

View File

@@ -1,11 +1,9 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_start_reason import WorkflowStartReason
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"WorkflowExecution",
"WorkflowNodeExecution",

View File

@@ -1,8 +0,0 @@
from pydantic import BaseModel
class AgentNodeStrategyInit(BaseModel):
"""Agent node strategy initialization data."""
name: str
icon: str | None = None

View File

@@ -0,0 +1,176 @@
from __future__ import annotations
import json
from abc import ABC
from builtins import type as type_
from enum import StrEnum
from typing import Any, Union
from pydantic import BaseModel, ConfigDict, Field, model_validator
from dify_graph.entities.exc import DefaultValueTypeError
from dify_graph.enums import ErrorStrategy, NodeType
# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`.
_NumberType = Union[int, float]
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
class DefaultValueType(StrEnum):
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_NUMBER = "array[number]"
ARRAY_STRING = "array[string]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILES = "array[file]"
class DefaultValue(BaseModel):
value: Any = None
type: DefaultValueType
key: str
@staticmethod
def _parse_json(value: str):
"""Unified JSON parsing handler"""
try:
return json.loads(value)
except json.JSONDecodeError:
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:
"""Unified number conversion handler"""
try:
return float(value)
except ValueError:
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators: dict[DefaultValueType, dict[str, Any]] = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": _NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
"type": dict,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": _NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
"type": list,
"element_type": str,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_OBJECT: {
"type": list,
"element_type": dict,
"converter": self._parse_json,
},
}
validator: dict[str, Any] = type_validators.get(self.type, {})
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
return self
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
# Handle string input cases
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
self.value = validator["converter"](self.value)
# Validate base type
if not isinstance(self.value, validator["type"]):
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
# Validate array element types
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
return self
class BaseNodeData(ABC, BaseModel):
# Raw graph payloads are first validated through `NodeConfigDictAdapter`, where
# `node["data"]` is typed as `BaseNodeData` before the concrete node class is known.
# At that boundary, node-specific fields are still "extra" relative to this shared DTO,
# and persisted templates/workflows also carry undeclared compatibility keys such as
# `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive
# here until graph parsing becomes discriminated by node type or those legacy payloads
# are normalized.
model_config = ConfigDict(extra="allow")
type: NodeType
title: str = ""
desc: str | None = None
version: str = "1"
error_strategy: ErrorStrategy | None = None
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = Field(default_factory=RetryConfig)
@property
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}
def __getitem__(self, key: str) -> Any:
"""
Dict-style access without calling model_dump() on every lookup.
Prefer using model fields and Pydantic's extra storage.
"""
# First, check declared model fields
if key in self.__class__.model_fields:
return getattr(self, key)
# Then, check undeclared compatibility fields stored in Pydantic's extra dict.
extras = getattr(self, "__pydantic_extra__", None)
if extras is None:
extras = getattr(self, "model_extra", None)
if extras is not None and key in extras:
return extras[key]
raise KeyError(key)
def get(self, key: str, default: Any = None) -> Any:
"""
Dict-style .get() without calling model_dump() on every lookup.
"""
if key in self.__class__.model_fields:
return getattr(self, key)
extras = getattr(self, "__pydantic_extra__", None)
if extras is None:
extras = getattr(self, "model_extra", None)
if extras is not None and key in extras:
return extras.get(key, default)
return default

View File

@@ -4,21 +4,20 @@ import sys
from pydantic import TypeAdapter, with_config
from dify_graph.entities.base_node_data import BaseNodeData
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
@with_config(extra="allow")
class NodeConfigData(TypedDict):
type: str
@with_config(extra="allow")
class NodeConfigDict(TypedDict):
id: str
data: NodeConfigData
# This is the permissive raw graph boundary. Node factories re-validate `data`
# with the concrete `NodeData` subtype after resolving the node implementation.
data: BaseNodeData
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)

View File

@@ -8,7 +8,7 @@ from typing import Protocol, cast, final
from pydantic import TypeAdapter
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState
from dify_graph.nodes.base.node import Node
from libs.typing import is_str
@@ -34,7 +34,8 @@ class NodeFactory(Protocol):
:param node_config: node configuration dictionary containing type and other data
:return: initialized Node instance
:raises ValueError: if node type is unknown or configuration is invalid
:raises ValueError: if node type is unknown or no implementation exists for the resolved version
:raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation
"""
...
@@ -115,10 +116,7 @@ class Graph:
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid]["data"]
node_type = node_data["type"]
if not isinstance(node_type, str):
continue
if NodeType(node_type).is_start_node:
if node_data.type.is_start_node:
start_node_id = nid
break
@@ -203,6 +201,23 @@ class Graph:
return GraphBuilder(graph_cls=cls)
@staticmethod
def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]:
"""
Remove editor-only nodes before `NodeConfigDict` validation.
Persisted note widgets use a top-level `type == "custom-note"` but leave
`data.type` empty because they are never executable graph nodes. Filter
them while configs are still raw dicts so Pydantic does not validate
their placeholder payloads against `BaseNodeData.type: NodeType`.
"""
filtered_node_configs: list[dict[str, object]] = []
for node_config in node_configs:
if node_config.get("type", "") == "custom-note":
continue
filtered_node_configs.append(dict(node_config))
return filtered_node_configs
@classmethod
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
"""
@@ -302,13 +317,13 @@ class Graph:
node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs)
node_configs = cast(list[dict[str, object]], node_configs)
node_configs = cls._filter_canvas_only_nodes(node_configs)
node_configs = _ListNodeConfigDict.validate_python(node_configs)
if not node_configs:
raise ValueError("Graph must have at least one node")
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)

View File

@@ -4,7 +4,6 @@ from datetime import datetime
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities import AgentNodeStrategyInit
from dify_graph.entities.pause_reason import PauseReason
from .base import GraphNodeEventBase
@@ -13,8 +12,8 @@ from .base import GraphNodeEventBase
class NodeRunStartedEvent(GraphNodeEventBase):
node_title: str
predecessor_node_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
start_at: datetime = Field(..., description="node start time")
extras: dict[str, object] = Field(default_factory=dict)
# FIXME(-LAN-): only for ToolNode
provider_type: str = ""

View File

@@ -276,7 +276,4 @@ class ToolPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
if not super().is_empty() and not self.tool_call_id:
return False
return True
return super().is_empty() and not self.tool_call_id

View File

@@ -4,7 +4,8 @@ class InvokeError(ValueError):
description: str | None = None
def __init__(self, description: str | None = None):
self.description = description
if description is not None:
self.description = description
def __str__(self):
return self.description or self.__class__.__name__

View File

@@ -282,7 +282,8 @@ class ModelProviderFactory:
all_model_type_models.append(model_schema)
simple_provider_schema = provider_schema.to_simple_provider()
simple_provider_schema.models.extend(all_model_type_models)
if model_type:
simple_provider_schema.models = all_model_type_models
providers.append(simple_provider_schema)

View File

@@ -1,3 +0,0 @@
from .agent_node import AgentNode
__all__ = ["AgentNode"]

View File

@@ -1,762 +0,0 @@
from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.enums import (
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
AgentLogEvent,
NodeEventBase,
NodeRunResult,
StreamChunkEvent,
StreamCompletedEvent,
)
from dify_graph.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.runtime import VariablePool
from dify_graph.variables.segments import ArrayFileSegment, StringSegment
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
AgentInputTypeError,
AgentInvocationError,
AgentMessageTransformError,
AgentNodeError,
AgentVariableNotFoundError,
AgentVariableTypeError,
ToolFileNotFoundError,
)
if TYPE_CHECKING:
from core.agent.strategy.plugin import PluginAgentStrategy
from core.plugin.entities.request import InvokeCredentials
class AgentNode(Node[AgentNodeData]):
"""
Agent Node
"""
node_type = NodeType.AGENT
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
dify_ctx = self.require_dify_context()
try:
strategy = get_plugin_agent_strategy(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self.node_data.agent_strategy_name,
)
except Exception as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error=f"Failed to get agent strategy: {str(e)}",
),
)
return
agent_parameters = strategy.get_parameters()
# get parameters
parameters = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
strategy=strategy,
)
parameters_for_log = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
for_log=True,
strategy=strategy,
)
credentials = self._generate_credentials(parameters=parameters)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = strategy.invoke(
params=parameters,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(error),
)
)
return
try:
yield from self._transform_message(
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": self.node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(transform_error),
)
)
def _generate_agent_parameters(
self,
*,
agent_parameters: Sequence[AgentStrategyParameter],
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
strategy: PluginAgentStrategy,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
Args:
agent_parameters (Sequence[AgentParameter]): The list of agent parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (AgentNodeData): The data associated with the agent node.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
case "mixed" | "constant":
# variable_pool.convert_template expects a string template,
# but if passing a dict, convert to JSON string first before rendering
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
# variable_pool.convert_template returns a string,
# so we need to convert it back to a dictionary
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
case _:
raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
value = self._filter_mcp_type_tool(strategy, value)
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
if not for_log:
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params}
entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""),
provider_type=provider_type,
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
)
extra = tool.get("extra", {})
# This is an issue that caused problems before.
# Logically, we shouldn't use the node_data.version field for judgment
# But for backward compatibility with historical data
# this version field judgment is still preserved here.
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
runtime_variable_pool = variable_pool
dify_ctx = self.require_dify_context()
tool_runtime = ToolManager.get_agent_tool_runtime(
dify_ctx.tenant_id,
dify_ctx.app_id,
entity,
dify_ctx.invoke_from,
runtime_variable_pool,
)
if tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
extra.get("description", "") or tool_runtime.entity.description.llm
)
for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = (
ToolParameter.ToolParameterForm.FORM
if tool_runtime_params.name in manual_input_params
else tool_runtime_params.form
)
manual_input_value = {}
if tool_runtime.entity.parameters:
manual_input_value = {
key: value for key, value in parameters.items() if key in manual_input_params
}
runtime_parameters = {
**tool_runtime.runtime.runtime_parameters,
**manual_input_value,
}
tool_value.append(
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value,
}
)
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
model_instance, model_schema = self._fetch_model(value)
# memory config
history_prompt_messages = []
if node_data.memory:
memory = self._fetch_memory(model_instance)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size or None
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
]
value["history_prompt_messages"] = history_prompt_messages
if model_schema:
# remove structured output feature to support old version agent plugin
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
value["entity"] = model_schema.model_dump(mode="json")
else:
value["entity"] = None
result[parameter_name] = value
return result
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
from core.plugin.entities.request import InvokeCredentials
credentials = InvokeCredentials()
# generate credentials for tools selector
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
if tool.get("credential_id"):
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
except ValidationError:
continue
return credentials
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = AgentNodeData.model_validate(node_data)
result: dict[str, Any] = {}
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
match input.type:
case "mixed" | "constant":
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()}
return result
@property
def agent_strategy_icon(self) -> str | None:
"""
Get agent strategy icon
:return:
"""
from core.plugin.impl.plugin import PluginInstaller
manager = PluginInstaller()
dify_ctx = self.require_dify_context()
plugins = manager.list_plugins(dify_ctx.tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:
icon = None
return icon
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
dify_ctx = self.require_dify_context()
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(
Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
)
conversation = session.scalar(stmt)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
dify_ctx = self.require_dify_context()
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
)
model_name = value.get("model", "")
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM, model=model_name
)
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
tenant_id=dify_ctx.tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_instance, model_schema
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
if model_schema.features:
for feature in model_schema.features[:]: # Create a copy to safely modify during iteration
try:
AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value
except ValueError:
model_schema.features.remove(feature)
return model_schema
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Filter MCP type tool
:param strategy: plugin agent strategy
:param tool: tool
:return: filtered tool dict
"""
meta_version = strategy.meta_version
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator[NodeEventBase, None, None]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json_list: list[dict | list] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage = LLMUsage.empty_usage()
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
if isinstance(message.message.json_object, dict):
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
else:
msg_metadata = {}
llm_usage = LLMUsage.empty_usage()
agent_execution_metadata = {}
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise AgentVariableTypeError(
"When 'stream' is True, 'variable_value' must be a string.",
variable_name=variable_name,
expected_type="str",
actual_type=type(variable_value).__name__,
)
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, dict)
# Validate that meta contains a 'file' key
if "file" not in message.meta:
raise AgentNodeError("File message is missing 'file' key in meta")
# Validate that the file is an instance of File
if not isinstance(message.meta["file"], File):
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
message_id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=node_id,
)
# check if the agent log is already in the list
for log in agent_logs:
if log.message_id == agent_log.message_id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any] | list[Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.message_id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json_list:
json_output.extend(json_list)
else:
json_output.append({"data": []})
# Send final chunk events for all streamed outputs
# Final chunk for text stream
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk="",
is_final=True,
)
# Final chunks for any streamed variables
for var_name in variables:
yield StreamChunkEvent(
selector=[node_id, var_name],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"text": text,
"usage": jsonable_encoder(llm_usage),
"files": ArrayFileSegment(value=files),
"json": json_output,
**variables,
},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)

View File

@@ -48,12 +48,10 @@ class AnswerNode(Node[AnswerNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: AnswerNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = AnswerNodeData.model_validate(node_data)
variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
_ = graph_config # Explicitly mark as unused
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {}

View File

@@ -3,7 +3,8 @@ from enum import StrEnum, auto
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class AnswerNodeData(BaseNodeData):
@@ -11,6 +12,7 @@ class AnswerNodeData(BaseNodeData):
Answer Node Data.
"""
type: NodeType = NodeType.ANSWER
answer: str = Field(..., description="answer template string")

View File

@@ -1,4 +1,4 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState
from .usage_tracking_mixin import LLMUsageTrackingMixin
__all__ = [
@@ -6,6 +6,5 @@ __all__ = [
"BaseIterationState",
"BaseLoopNodeData",
"BaseLoopState",
"BaseNodeData",
"LLMUsageTrackingMixin",
]

View File

@@ -1,31 +1,12 @@
from __future__ import annotations
import json
from abc import ABC
from builtins import type as type_
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Union
from typing import Any
from pydantic import BaseModel, field_validator, model_validator
from pydantic import BaseModel, field_validator
from dify_graph.enums import ErrorStrategy
from .exc import DefaultValueTypeError
_NumberType = Union[int, float]
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
from dify_graph.entities.base_node_data import BaseNodeData
class VariableSelector(BaseModel):
@@ -76,112 +57,6 @@ class OutputVariableEntity(BaseModel):
return v
class DefaultValueType(StrEnum):
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_NUMBER = "array[number]"
ARRAY_STRING = "array[string]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILES = "array[file]"
class DefaultValue(BaseModel):
value: Any = None
type: DefaultValueType
key: str
@staticmethod
def _parse_json(value: str):
"""Unified JSON parsing handler"""
try:
return json.loads(value)
except json.JSONDecodeError:
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:
"""Unified number conversion handler"""
try:
return float(value)
except ValueError:
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators: dict[DefaultValueType, dict[str, Any]] = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": _NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
"type": dict,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": _NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
"type": list,
"element_type": str,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_OBJECT: {
"type": list,
"element_type": dict,
"converter": self._parse_json,
},
}
validator: dict[str, Any] = type_validators.get(self.type, {})
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
return self
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
# Handle string input cases
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
self.value = validator["converter"](self.value)
# Validate base type
if not isinstance(self.value, validator["type"]):
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
# Validate array element types
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
return self
class BaseNodeData(ABC, BaseModel):
title: str
desc: str | None = None
version: str = "1"
error_strategy: ErrorStrategy | None = None
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
@property
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}
class BaseIterationNodeData(BaseNodeData):
start_node_id: str | None = None

View File

@@ -11,7 +11,9 @@ from types import MappingProxyType
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
from uuid import uuid4
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
from dify_graph.entities import GraphInitParams
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import (
ErrorStrategy,
@@ -62,8 +64,6 @@ from dify_graph.node_events import (
from dify_graph.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
from .entities import BaseNodeData, RetryConfig
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
_MISSING_RUN_CONTEXT_VALUE = object()
@@ -153,11 +153,11 @@ class Node(Generic[NodeDataT]):
Later, in __init__:
::
config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
CodeNodeData instance
(stored in self._node_data)
config["data"] ──► _node_data_type.model_validate(..., from_attributes=True)
CodeNodeData instance
(stored in self._node_data)
Example:
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
@@ -241,7 +241,7 @@ class Node(Generic[NodeDataT]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
) -> None:
@@ -254,22 +254,21 @@ class Node(Generic[NodeDataT]):
self.graph_runtime_state = graph_runtime_state
self.state: NodeState = NodeState.UNKNOWN # node execution state
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required.")
node_id = config["id"]
self._node_id = node_id
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
raw_node_data = config.get("data") or {}
if not isinstance(raw_node_data, Mapping):
raise ValueError("Node config data must be a mapping.")
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
self._node_data = self.validate_node_data(config["data"])
self.post_init()
@classmethod
def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT:
"""Validate shared graph node payloads against the subclass-declared NodeData model."""
return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True))
def post_init(self) -> None:
"""Optional hook for subclasses requiring extra initialization."""
return
@@ -342,9 +341,6 @@ class Node(Generic[NodeDataT]):
return None
return str(execution_id)
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@abstractmethod
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
"""
@@ -353,6 +349,10 @@ class Node(Generic[NodeDataT]):
"""
raise NotImplementedError
def populate_start_event(self, event: NodeRunStartedEvent) -> None:
"""Allow subclasses to enrich the started event without cross-node imports in the base class."""
_ = event
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
@@ -366,41 +366,10 @@ class Node(Generic[NodeDataT]):
in_iteration_id=None,
start_at=self._start_at,
)
# === FIXME(-LAN-): Needs to refactor.
from dify_graph.nodes.tool.tool_node import ToolNode
if isinstance(self, ToolNode):
start_event.provider_id = getattr(self.node_data, "provider_id", "")
start_event.provider_type = getattr(self.node_data, "provider_type", "")
from dify_graph.nodes.datasource.datasource_node import DatasourceNode
if isinstance(self, DatasourceNode):
plugin_id = getattr(self.node_data, "plugin_id", "")
provider_name = getattr(self.node_data, "provider_name", "")
start_event.provider_id = f"{plugin_id}/{provider_name}"
start_event.provider_type = getattr(self.node_data, "provider_type", "")
from dify_graph.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
if isinstance(self, TriggerEventNode):
start_event.provider_id = getattr(self.node_data, "provider_id", "")
start_event.provider_type = getattr(self.node_data, "provider_type", "")
from typing import cast
from dify_graph.nodes.agent.agent_node import AgentNode
from dify_graph.nodes.agent.entities import AgentNodeData
if isinstance(self, AgentNode):
start_event.agent_strategy = AgentNodeStrategyInit(
name=cast(AgentNodeData, self.node_data).agent_strategy_name,
icon=self.agent_strategy_icon,
)
# ===
try:
self.populate_start_event(start_event)
except Exception:
logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True)
yield start_event
try:
@@ -442,7 +411,7 @@ class Node(Generic[NodeDataT]):
cls,
*,
graph_config: Mapping[str, Any],
config: Mapping[str, Any],
config: NodeConfigDict,
) -> Mapping[str, Sequence[str]]:
"""Extracts references variable selectors from node configuration.
@@ -480,13 +449,12 @@ class Node(Generic[NodeDataT]):
:param config: node config
:return:
"""
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
# Pass raw dict data instead of creating NodeData instance
node_id = config["id"]
node_data = cls.validate_node_data(config["data"])
data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
graph_config=graph_config,
node_id=node_id,
node_data=node_data,
)
return data
@@ -496,7 +464,7 @@ class Node(Generic[NodeDataT]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: NodeDataT,
) -> Mapping[str, Sequence[str]]:
return {}
@@ -520,10 +488,8 @@ class Node(Generic[NodeDataT]):
@abstractmethod
def version(cls) -> str:
"""`node_version` returns the version of current node type."""
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
#
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
# in `api/dify_graph/nodes/__init__.py`.
# NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so
# `Node.get_node_type_classes_mapping()` can resolve numeric versions and `latest`.
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@classmethod
@@ -531,7 +497,9 @@ class Node(Generic[NodeDataT]):
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
Import all modules under dify_graph.nodes so subclasses register themselves on import.
Then we return a readonly view of the registry to avoid accidental mutation.
Callers that rely on workflow-local nodes defined outside `dify_graph.nodes` must import
those modules before invoking this method so they can register through `__init_subclass__`.
We then return a readonly view of the registry to avoid accidental mutation.
"""
# Import all node modules to ensure they are loaded (thus registered)
import dify_graph.nodes as _nodes_pkg

View File

@@ -3,6 +3,7 @@ from decimal import Decimal
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Protocol, cast
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
@@ -77,7 +78,7 @@ class CodeNode(Node[CodeNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@@ -466,15 +467,12 @@ class CodeNode(Node[CodeNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: CodeNodeData,
) -> Mapping[str, Sequence[str]]:
_ = graph_config # Explicitly mark as unused
# Create typed NodeData from dict
typed_node_data = CodeNodeData.model_validate(node_data)
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in typed_node_data.variables
for variable_selector in node_data.variables
}
@property

View File

@@ -3,7 +3,8 @@ from typing import Annotated, Literal
from pydantic import AfterValidator, BaseModel
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.variables.types import SegmentType
@@ -39,6 +40,8 @@ class CodeNodeData(BaseNodeData):
Code Node Data.
"""
type: NodeType = NodeType.CODE
class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: dict[str, "CodeNodeData.Output"] | None = None

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.node_events import NodeRunResult, StreamCompletedEvent
@@ -34,7 +35,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
datasource_manager: DatasourceManagerProtocol,
@@ -47,6 +48,10 @@ class DatasourceNode(Node[DatasourceNodeData]):
)
self.datasource_manager = datasource_manager
def populate_start_event(self, event) -> None:
event.provider_id = f"{self.node_data.plugin_id}/{self.node_data.provider_name}"
event.provider_type = self.node_data.provider_type
def _run(self) -> Generator:
"""
Run the datasource node
@@ -181,7 +186,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: DatasourceNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -190,11 +195,10 @@ class DatasourceNode(Node[DatasourceNodeData]):
:param node_data: node data
:return:
"""
typed_node_data = DatasourceNodeData.model_validate(node_data)
result = {}
if typed_node_data.datasource_parameters:
for parameter_name in typed_node_data.datasource_parameters:
input = typed_node_data.datasource_parameters[parameter_name]
if node_data.datasource_parameters:
for parameter_name in node_data.datasource_parameters:
input = node_data.datasource_parameters[parameter_name]
match input.type:
case "mixed":
assert isinstance(input.value, str)

View File

@@ -3,7 +3,8 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class DatasourceEntity(BaseModel):
@@ -16,6 +17,8 @@ class DatasourceEntity(BaseModel):
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
type: NodeType = NodeType.DATASOURCE
class DatasourceInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]

View File

@@ -1,10 +1,12 @@
from collections.abc import Sequence
from dataclasses import dataclass
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class DocumentExtractorNodeData(BaseNodeData):
type: NodeType = NodeType.DOCUMENT_EXTRACTOR
variable_selector: Sequence[str]

View File

@@ -21,6 +21,7 @@ from docx.oxml.text.paragraph import CT_P
from docx.table import Table
from docx.text.paragraph import Paragraph
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.file import File, FileTransferMethod, file_manager
from dify_graph.node_events import NodeRunResult
@@ -54,7 +55,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@@ -136,12 +137,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: DocumentExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
return {node_id + ".files": typed_node_data.variable_selector}
_ = graph_config # Explicitly mark as unused
return {node_id + ".files": node_data.variable_selector}
def _extract_text_by_mime_type(

View File

@@ -1,6 +1,8 @@
from pydantic import BaseModel, Field
from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base.entities import OutputVariableEntity
class EndNodeData(BaseNodeData):
@@ -8,6 +10,7 @@ class EndNodeData(BaseNodeData):
END Node Data.
"""
type: NodeType = NodeType.END
outputs: list[OutputVariableEntity]

View File

@@ -8,7 +8,8 @@ import charset_normalizer
import httpx
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config"
@@ -89,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData):
Code Node Data.
"""
type: NodeType = NodeType.HTTP_REQUEST
method: Literal[
"get",
"post",

View File

@@ -3,6 +3,7 @@ import mimetypes
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.file import File, FileTransferMethod
from dify_graph.node_events import NodeRunResult
@@ -37,7 +38,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@@ -163,18 +164,15 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: HttpRequestNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = HttpRequestNodeData.model_validate(node_data)
selectors: list[VariableSelector] = []
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
if typed_node_data.body:
body_type = typed_node_data.body.type
data = typed_node_data.body.data
selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
if node_data.body:
body_type = node_data.body.type
data = node_data.body.data
match body_type:
case "none":
pass

View File

@@ -10,7 +10,8 @@ from typing import Annotated, Any, ClassVar, Literal, Self
from pydantic import BaseModel, Field, field_validator, model_validator
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.runtime import VariablePool
from dify_graph.variables.consts import SELECTORS_LENGTH
@@ -71,8 +72,8 @@ class EmailDeliveryConfig(BaseModel):
body: str
debug_mode: bool = False
def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig":
if not user_id:
def with_debug_recipient(self, user_id: str | None) -> "EmailDeliveryConfig":
if user_id is None:
debug_recipients = EmailRecipients(whole_workspace=False, items=[])
return self.model_copy(update={"recipients": debug_recipients})
debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)])
@@ -140,7 +141,7 @@ def apply_debug_email_recipient(
method: DeliveryChannelConfig,
*,
enabled: bool,
user_id: str,
user_id: str | None,
) -> DeliveryChannelConfig:
if not enabled:
return method
@@ -148,7 +149,7 @@ def apply_debug_email_recipient(
return method
if not method.config.debug_mode:
return method
debug_config = method.config.with_debug_recipient(user_id or "")
debug_config = method.config.with_debug_recipient(user_id)
return method.model_copy(update={"config": debug_config})
@@ -214,6 +215,7 @@ class UserAction(BaseModel):
class HumanInputNodeData(BaseNodeData):
"""Human Input node data."""
type: NodeType = NodeType.HUMAN_INPUT
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
form_content: str = ""
inputs: list[FormInput] = Field(default_factory=list)

View File

@@ -3,6 +3,7 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import (
@@ -63,7 +64,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
form_repository: HumanInputFormRepository,
@@ -348,7 +349,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: HumanInputNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selectors referenced in form content and input default values.
@@ -357,5 +358,4 @@ class HumanInputNode(Node[HumanInputNodeData]):
1. Variables referenced in form_content ({{#node_name.var_name#}})
2. Variables referenced in input default values
"""
validated_node_data = HumanInputNodeData.model_validate(node_data)
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
return node_data.extract_variable_selector_to_variable_mapping(node_id)

View File

@@ -2,7 +2,8 @@ from typing import Literal
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.utils.condition.entities import Condition
@@ -11,6 +12,8 @@ class IfElseNodeData(BaseNodeData):
If Else Node Data.
"""
type: NodeType = NodeType.IF_ELSE
class Case(BaseModel):
"""
Case entity representing a single logical condition group

View File

@@ -97,13 +97,11 @@ class IfElseNode(Node[IfElseNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: IfElseNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = IfElseNodeData.model_validate(node_data)
var_mapping: dict[str, list[str]] = {}
for case in typed_node_data.cases or []:
_ = graph_config # Explicitly mark as unused
for case in node_data.cases or []:
for condition in case.conditions:
key = f"{node_id}.#{'.'.join(condition.variable_selector)}#"
var_mapping[key] = condition.variable_selector

View File

@@ -3,7 +3,9 @@ from typing import Any
from pydantic import Field
from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState
class ErrorHandleMode(StrEnum):
@@ -17,6 +19,7 @@ class IterationNodeData(BaseIterationNodeData):
Iteration Node Data.
"""
type: NodeType = NodeType.ITERATION
parent_loop_id: str | None = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
@@ -31,7 +34,7 @@ class IterationStartNodeData(BaseNodeData):
Iteration Start Node Data.
"""
pass
type: NodeType = NodeType.ITERATION_START
class IterationState(BaseIterationState):

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
from typing_extensions import TypeIs
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import (
NodeExecutionType,
NodeType,
@@ -460,21 +461,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: IterationNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = IterationNodeData.model_validate(node_data)
variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": typed_node_data.iterator_selector,
f"{node_id}.input_selector": node_data.iterator_selector,
}
iteration_node_ids = set()
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("iteration_id") == node_id:
node_config_data = node.get("data", {})
if node_config_data.get("iteration_id") == node_id:
in_iteration_node_id = node.get("id")
if in_iteration_node_id:
iteration_node_ids.add(in_iteration_node_id)
@@ -488,16 +486,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
# variable selector to variable mapping
try:
# Get node class
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from dify_graph.nodes.node_mapping import get_node_type_classes_mapping
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
if node_type not in NODE_TYPE_CLASSES_MAPPING:
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
node_type = typed_sub_node_config["data"].type
node_mapping = get_node_type_classes_mapping()
if node_type not in node_mapping:
continue
node_version = sub_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
node_version = str(typed_sub_node_config["data"].version)
node_cls = node_mapping[node_type][node_version]
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
graph_config=graph_config, config=typed_sub_node_config
)
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
except NotImplementedError:

View File

@@ -3,7 +3,8 @@ from typing import Literal, Union
from pydantic import BaseModel
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class RerankingModelConfig(BaseModel):
@@ -155,7 +156,7 @@ class KnowledgeIndexNodeData(BaseNodeData):
Knowledge index Node Data.
"""
type: str = "knowledge-index"
type: NodeType = NodeType.KNOWLEDGE_INDEX
chunk_structure: str
index_chunk_variable_selector: list[str]
indexing_technique: str | None = None

View File

@@ -2,6 +2,7 @@ import logging
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.node_events import NodeRunResult
@@ -30,7 +31,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
index_processor: IndexProcessorProtocol,

View File

@@ -3,7 +3,8 @@ from typing import Literal
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
@@ -113,7 +114,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
Knowledge retrieval Node Data.
"""
type: str = "knowledge-retrieval"
type: NodeType = NodeType.KNOWLEDGE_RETRIEVAL
query_variable_selector: list[str] | None | str = None
query_attachment_selector: list[str] | None | str = None
dataset_ids: list[str]

View File

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Literal
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
@@ -49,7 +50,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
rag_retrieval: RAGRetrievalProtocol,
@@ -301,15 +302,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: KnowledgeRetrievalNodeData,
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
variable_mapping = {}
if typed_node_data.query_variable_selector:
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
if typed_node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
if node_data.query_variable_selector:
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
if node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = node_data.query_attachment_selector
return variable_mapping

View File

@@ -3,7 +3,8 @@ from enum import StrEnum
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class FilterOperator(StrEnum):
@@ -62,6 +63,7 @@ class ExtractConfig(BaseModel):
class ListOperatorNodeData(BaseNodeData):
type: NodeType = NodeType.LIST_OPERATOR
variable: Sequence[str] = Field(default_factory=list)
filter_by: FilterBy
order_by: OrderByConfig

View File

@@ -4,8 +4,9 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.nodes.base import BaseNodeData
from dify_graph.nodes.base.entities import VariableSelector
@@ -59,6 +60,7 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
class LLMNodeData(BaseNodeData):
type: NodeType = NodeType.LLM
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: PromptConfig = Field(default_factory=PromptConfig)

View File

@@ -21,6 +21,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
NodeType,
SystemVariableKey,
@@ -121,7 +122,7 @@ class LLMNode(Node[LLMNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
@@ -954,14 +955,11 @@ class LLMNode(Node[LLMNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: LLMNodeData,
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
_ = graph_config # Explicitly mark as unused
# Create typed NodeData from dict
typed_node_data = LLMNodeData.model_validate(node_data)
prompt_template = typed_node_data.prompt_template
prompt_template = node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list):
for prompt in prompt_template:
@@ -979,7 +977,7 @@ class LLMNode(Node[LLMNodeData]):
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
memory = typed_node_data.memory
memory = node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
@@ -987,16 +985,16 @@ class LLMNode(Node[LLMNodeData]):
for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
if typed_node_data.context.enabled:
variable_mapping["#context#"] = typed_node_data.context.variable_selector
if node_data.context.enabled:
variable_mapping["#context#"] = node_data.context.variable_selector
if typed_node_data.vision.enabled:
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
if node_data.vision.enabled:
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
if typed_node_data.memory:
if node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
if typed_node_data.prompt_config:
if node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
@@ -1009,7 +1007,7 @@ class LLMNode(Node[LLMNodeData]):
break
if enable_jinja:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}

View File

@@ -3,7 +3,9 @@ from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
from dify_graph.utils.condition.entities import Condition
from dify_graph.variables.types import SegmentType
@@ -39,6 +41,7 @@ class LoopVariableData(BaseModel):
class LoopNodeData(BaseLoopNodeData):
type: NodeType = NodeType.LOOP
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
@@ -58,7 +61,7 @@ class LoopStartNodeData(BaseNodeData):
Loop Start Node Data.
"""
pass
type: NodeType = NodeType.LOOP_START
class LoopEndNodeData(BaseNodeData):
@@ -66,7 +69,7 @@ class LoopEndNodeData(BaseNodeData):
Loop End Node Data.
"""
pass
type: NodeType = NodeType.LOOP_END
class LoopState(BaseLoopState):

View File

@@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import (
NodeExecutionType,
NodeType,
@@ -298,11 +299,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: LoopNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = LoopNodeData.model_validate(node_data)
variable_mapping = {}
# Extract loop node IDs statically from graph_config
@@ -318,16 +316,18 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
# variable selector to variable mapping
try:
# Get node class
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from dify_graph.nodes.node_mapping import get_node_type_classes_mapping
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
if node_type not in NODE_TYPE_CLASSES_MAPPING:
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
node_type = typed_sub_node_config["data"].type
node_mapping = get_node_type_classes_mapping()
if node_type not in node_mapping:
continue
node_version = sub_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
node_version = str(typed_sub_node_config["data"].version)
node_cls = node_mapping[node_type][node_version]
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
graph_config=graph_config, config=typed_sub_node_config
)
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
except NotImplementedError:
@@ -342,7 +342,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
variable_mapping.update(sub_node_variable_mapping)
for loop_variable in typed_node_data.loop_variables or []:
for loop_variable in node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping

View File

@@ -5,5 +5,24 @@ from dify_graph.nodes.base.node import Node
LATEST_VERSION = "latest"
# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks dify_graph.nodes
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]:
"""Return the live node registry after importing all `dify_graph.nodes` modules."""
return Node.get_node_type_classes_mapping()
def resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
node_mapping = get_node_type_classes_mapping().get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
latest_node_class = node_mapping.get(LATEST_VERSION)
matched_node_class = node_mapping.get(node_version)
node_class = matched_node_class or latest_node_class
if not node_class:
raise ValueError(f"No latest version class found for node type: {node_type}")
return node_class
# Snapshot kept for compatibility with older tests; production paths should use the live helpers.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = get_node_type_classes_mapping()

View File

@@ -8,7 +8,8 @@ from pydantic import (
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
from dify_graph.variables.types import SegmentType
@@ -83,6 +84,7 @@ class ParameterExtractorNodeData(BaseNodeData):
Parameter Extractor Node Data.
"""
type: NodeType = NodeType.PARAMETER_EXTRACTOR
model: ModelConfig
query: list[str]
parameters: list[ParameterConfig]

View File

@@ -10,6 +10,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
@@ -106,7 +107,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@@ -837,15 +838,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: ParameterExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
_ = graph_config # Explicitly mark as unused
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
if typed_node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
if node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector

View File

@@ -1,7 +1,8 @@
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.llm import ModelConfig, VisionConfig
@@ -11,6 +12,7 @@ class ClassConfig(BaseModel):
class QuestionClassifierNodeData(BaseNodeData):
type: NodeType = NodeType.QUESTION_CLASSIFIER
query_variable_selector: list[str]
model: ModelConfig
classes: list[ClassConfig]

View File

@@ -7,6 +7,7 @@ from core.model_manager import ModelInstance
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
NodeExecutionType,
NodeType,
@@ -62,7 +63,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@@ -251,16 +252,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: QuestionClassifierNodeData,
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
variable_mapping = {"query": typed_node_data.query_variable_selector}
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors: list[VariableSelector] = []
if typed_node_data.instruction:
variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)

View File

@@ -2,7 +2,8 @@ from collections.abc import Sequence
from pydantic import Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.variables.input_entities import VariableEntity
@@ -11,4 +12,5 @@ class StartNodeData(BaseNodeData):
Start Node Data
"""
type: NodeType = NodeType.START
variables: Sequence[VariableEntity] = Field(default_factory=list)

View File

@@ -1,4 +1,5 @@
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.base.entities import VariableSelector
@@ -7,5 +8,6 @@ class TemplateTransformNodeData(BaseNodeData):
Template Transform Node Data.
"""
type: NodeType = NodeType.TEMPLATE_TRANSFORM
variables: list[VariableSelector]
template: str

View File

@@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
@@ -25,7 +26,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@@ -86,12 +87,9 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = TemplateTransformNodeData.model_validate(node_data)
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in typed_node_data.variables
for variable_selector in node_data.variables
}

View File

@@ -4,7 +4,8 @@ from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class ToolEntity(BaseModel):
@@ -32,6 +33,8 @@ class ToolEntity(BaseModel):
class ToolNodeData(BaseNodeData, ToolEntity):
type: NodeType = NodeType.TOOL
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]

View File

@@ -7,6 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
NodeType,
SystemVariableKey,
@@ -46,7 +47,7 @@ class ToolNode(Node[ToolNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@@ -64,6 +65,10 @@ class ToolNode(Node[ToolNodeData]):
def version(cls) -> str:
return "1"
def populate_start_event(self, event) -> None:
event.provider_id = self.node_data.provider_id
event.provider_type = self.node_data.provider_type
def _run(self) -> Generator[NodeEventBase, None, None]:
"""
Run the tool node
@@ -484,7 +489,7 @@ class ToolNode(Node[ToolNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: ToolNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -493,9 +498,8 @@ class ToolNode(Node[ToolNodeData]):
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = ToolNodeData.model_validate(node_data)
_ = graph_config # Explicitly mark as unused
typed_node_data = node_data
result = {}
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]

View File

@@ -4,13 +4,16 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from core.trigger.entities.entities import EventParameter
from dify_graph.nodes.base.entities import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.nodes.trigger_plugin.exc import TriggerEventParameterError
class TriggerEventNodeData(BaseNodeData):
"""Plugin trigger node data"""
type: NodeType = NodeType.TRIGGER_PLUGIN
class TriggerEventInput(BaseModel):
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
@@ -38,8 +41,6 @@ class TriggerEventNodeData(BaseNodeData):
raise ValueError("value must be a string, int, float, bool or dict")
return type
title: str
desc: str | None = None
plugin_id: str = Field(..., description="Plugin ID")
provider_id: str = Field(..., description="Provider ID")
event_name: str = Field(..., description="Event name")

View File

@@ -32,6 +32,9 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
def version(cls) -> str:
return "1"
def populate_start_event(self, event) -> None:
event.provider_id = self.node_data.provider_id
def _run(self) -> NodeRunResult:
"""
Run the plugin trigger node.

View File

@@ -2,7 +2,8 @@ from typing import Literal, Union
from pydantic import BaseModel, Field
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class TriggerScheduleNodeData(BaseNodeData):
@@ -10,6 +11,7 @@ class TriggerScheduleNodeData(BaseNodeData):
Trigger Schedule Node Data
"""
type: NodeType = NodeType.TRIGGER_SCHEDULE
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")

View File

@@ -1,4 +1,4 @@
from dify_graph.nodes.base.exc import BaseNodeError
from dify_graph.entities.exc import BaseNodeError
class ScheduleNodeError(BaseNodeError):

View File

@@ -1,10 +1,41 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel, Field, field_validator
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.variables.types import SegmentType
_WEBHOOK_HEADER_ALLOWED_TYPES = frozenset(
{
SegmentType.STRING,
}
)
_WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES = frozenset(
{
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
}
)
_WEBHOOK_PARAMETER_ALLOWED_TYPES = _WEBHOOK_HEADER_ALLOWED_TYPES | _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES
_WEBHOOK_BODY_ALLOWED_TYPES = frozenset(
{
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
SegmentType.OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_BOOLEAN,
SegmentType.ARRAY_OBJECT,
SegmentType.FILE,
}
)
class Method(StrEnum):
@@ -25,29 +56,34 @@ class ContentType(StrEnum):
class WebhookParameter(BaseModel):
"""Parameter definition for headers, query params, or body."""
"""Parameter definition for headers or query params."""
name: str
type: SegmentType = SegmentType.STRING
required: bool = False
@field_validator("type", mode="after")
@classmethod
def validate_type(cls, v: SegmentType) -> SegmentType:
if v not in _WEBHOOK_PARAMETER_ALLOWED_TYPES:
raise ValueError(f"Unsupported webhook parameter type: {v}")
return v
class WebhookBodyParameter(BaseModel):
"""Body parameter with type information."""
name: str
type: Literal[
"string",
"number",
"boolean",
"object",
"array[string]",
"array[number]",
"array[boolean]",
"array[object]",
"file",
] = "string"
type: SegmentType = SegmentType.STRING
required: bool = False
@field_validator("type", mode="after")
@classmethod
def validate_type(cls, v: SegmentType) -> SegmentType:
if v not in _WEBHOOK_BODY_ALLOWED_TYPES:
raise ValueError(f"Unsupported webhook body parameter type: {v}")
return v
class WebhookData(BaseNodeData):
"""
@@ -57,6 +93,7 @@ class WebhookData(BaseNodeData):
class SyncMode(StrEnum):
SYNC = "async" # only support
type: NodeType = NodeType.TRIGGER_WEBHOOK
method: Method = Method.GET
content_type: ContentType = Field(default=ContentType.JSON)
headers: Sequence[WebhookParameter] = Field(default_factory=list)
@@ -71,6 +108,22 @@ class WebhookData(BaseNodeData):
return v.lower()
return v
@field_validator("headers", mode="after")
@classmethod
def validate_header_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]:
for param in v:
if param.type not in _WEBHOOK_HEADER_ALLOWED_TYPES:
raise ValueError(f"Unsupported webhook header parameter type: {param.type}")
return v
@field_validator("params", mode="after")
@classmethod
def validate_query_parameter_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]:
for param in v:
if param.type not in _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES:
raise ValueError(f"Unsupported webhook query parameter type: {param.type}")
return v
status_code: int = 200 # Expected status code for response
response_body: str = "" # Template for response body

View File

@@ -1,4 +1,4 @@
from dify_graph.nodes.base.exc import BaseNodeError
from dify_graph.entities.exc import BaseNodeError
class WebhookNodeError(BaseNodeError):

View File

@@ -152,7 +152,7 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs[param_name] = raw_data
continue
if param_type == "file":
if param_type == SegmentType.FILE:
# Get File object (already processed by webhook controller)
files = webhook_data.get("files", {})
if files and isinstance(files, dict):

View File

@@ -1,6 +1,7 @@
from pydantic import BaseModel
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.variables.types import SegmentType
@@ -28,6 +29,7 @@ class VariableAggregatorNodeData(BaseNodeData):
Variable Aggregator Node Data.
"""
type: NodeType = NodeType.VARIABLE_AGGREGATOR
output_type: str
variables: list[list[str]]
advanced_settings: AdvancedSettings | None = None

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
@@ -22,7 +23,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
@@ -52,21 +53,18 @@ class VariableAssignerNode(Node[VariableAssignerData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: VariableAssignerData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerData.model_validate(node_data)
mapping = {}
assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
assigned_variable_node_id = node_data.assigned_variable_selector[0]
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
selector_key = ".".join(typed_node_data.assigned_variable_selector)
selector_key = ".".join(node_data.assigned_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = typed_node_data.assigned_variable_selector
mapping[key] = node_data.assigned_variable_selector
selector_key = ".".join(typed_node_data.input_variable_selector)
selector_key = ".".join(node_data.input_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = typed_node_data.input_variable_selector
mapping[key] = node_data.input_variable_selector
return mapping
def _run(self) -> NodeRunResult:

View File

@@ -1,7 +1,8 @@
from collections.abc import Sequence
from enum import StrEnum
from dify_graph.nodes.base import BaseNodeData
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
class WriteMode(StrEnum):
@@ -11,6 +12,7 @@ class WriteMode(StrEnum):
class VariableAssignerData(BaseNodeData):
type: NodeType = NodeType.VARIABLE_ASSIGNER
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]

Some files were not shown because too many files have changed in this diff Show More