mirror of
https://github.com/langgenius/dify.git
synced 2026-03-15 04:07:01 +00:00
Compare commits
86 Commits
fix/vinext
...
move-trigg
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a30a64d51b | ||
|
|
34ef10c818 | ||
|
|
26fedca865 | ||
|
|
2fd4e9e259 | ||
|
|
9e8a4c8a71 | ||
|
|
238497b7ab | ||
|
|
79a8747c1b | ||
|
|
e64f4d6039 | ||
|
|
6043ec4423 | ||
|
|
a600794370 | ||
|
|
573b4e41cb | ||
|
|
194c205ed3 | ||
|
|
7e1dc3c122 | ||
|
|
4203647c32 | ||
|
|
20e91990bf | ||
|
|
f38e8cca52 | ||
|
|
00eda73ad1 | ||
|
|
8b40a89add | ||
|
|
97776eabff | ||
|
|
fe561ef3d0 | ||
|
|
1104d35bbb | ||
|
|
724eaee77e | ||
|
|
4717168fe2 | ||
|
|
7fd3bd81ab | ||
|
|
0dcfac5b84 | ||
|
|
b66097b5f3 | ||
|
|
ceaa399351 | ||
|
|
dc50e4c4f2 | ||
|
|
157208ab1e | ||
|
|
3dabdc8282 | ||
|
|
ed5511ce28 | ||
|
|
68982f910e | ||
|
|
c43307dae1 | ||
|
|
b44b37518a | ||
|
|
b170eabaf3 | ||
|
|
e99628b76f | ||
|
|
60fe5e7f00 | ||
|
|
245f6b824d | ||
|
|
7d2054d4f4 | ||
|
|
07e19c0748 | ||
|
|
135b3a15a6 | ||
|
|
0045e387f5 | ||
|
|
44713a5c0f | ||
|
|
d5724aebde | ||
|
|
c59685748c | ||
|
|
36c1f4d506 | ||
|
|
31eba65fe0 | ||
|
|
72496a5847 | ||
|
|
8b16030d6b | ||
|
|
989db0e584 | ||
|
|
a0f0c97133 | ||
|
|
1505ec37ef | ||
|
|
7f09917b84 | ||
|
|
59367c6a1c | ||
|
|
f95e2acb65 | ||
|
|
8967b88584 | ||
|
|
6110c0a66c | ||
|
|
5640a2e47c | ||
|
|
60c8ee2a86 | ||
|
|
a97571156b | ||
|
|
084f2eb612 | ||
|
|
02f36bd9b5 | ||
|
|
2b1d1e9587 | ||
|
|
e85d20031e | ||
|
|
8a6a3ef0e4 | ||
|
|
2a3eb87326 | ||
|
|
5ff7d2c895 | ||
|
|
b2df0010ce | ||
|
|
f44cd70752 | ||
|
|
ee78ffcdb1 | ||
|
|
5d0c3d58ac | ||
|
|
a6e8e43883 | ||
|
|
0a320c63a0 | ||
|
|
1f4b6c84e0 | ||
|
|
d6721a1dd3 | ||
|
|
8b12389e19 | ||
|
|
0fad13370c | ||
|
|
904289bb1d | ||
|
|
8fe376848f | ||
|
|
d67f04f63e | ||
|
|
54637144c5 | ||
|
|
27f9cdedad | ||
|
|
5a5238062a | ||
|
|
7adcc17da0 | ||
|
|
1083f5c46a | ||
|
|
86b6868772 |
34
.github/actions/setup-web/action.yml
vendored
34
.github/actions/setup-web/action.yml
vendored
@@ -1,33 +1,13 @@
|
||||
name: Setup Web Environment
|
||||
description: Setup pnpm, Node.js, and install web dependencies.
|
||||
|
||||
inputs:
|
||||
node-version:
|
||||
description: Node.js version to use
|
||||
required: false
|
||||
default: "22"
|
||||
install-dependencies:
|
||||
description: Whether to install web dependencies after setting up Node.js
|
||||
required: false
|
||||
default: "true"
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@41ff72655975bd51cab0327fa583b6e92b6d3061 # v4.2.0
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@b5d848f5a62488f3d3d920f8aa6ac318a60c5f07 # v1
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
|
||||
with:
|
||||
node-version: ${{ inputs.node-version }}
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
if: ${{ inputs.install-dependencies == 'true' }}
|
||||
shell: bash
|
||||
run: pnpm --dir web install --frozen-lockfile
|
||||
node-version-file: "./web/.nvmrc"
|
||||
cache: true
|
||||
run-install: |
|
||||
- cwd: ./web
|
||||
args: ['--frozen-lockfile']
|
||||
|
||||
224
.github/dependabot.yml
vendored
224
.github/dependabot.yml
vendored
@@ -3,54 +3,210 @@ version: 2
|
||||
updates:
|
||||
- package-ecosystem: "pip"
|
||||
directory: "/api"
|
||||
open-pull-requests-limit: 2
|
||||
open-pull-requests-limit: 10
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
python-dependencies:
|
||||
flask:
|
||||
patterns:
|
||||
- "flask"
|
||||
- "flask-*"
|
||||
- "werkzeug"
|
||||
- "gunicorn"
|
||||
google:
|
||||
patterns:
|
||||
- "google-*"
|
||||
- "googleapis-*"
|
||||
opentelemetry:
|
||||
patterns:
|
||||
- "opentelemetry-*"
|
||||
pydantic:
|
||||
patterns:
|
||||
- "pydantic"
|
||||
- "pydantic-*"
|
||||
llm:
|
||||
patterns:
|
||||
- "langfuse"
|
||||
- "langsmith"
|
||||
- "litellm"
|
||||
- "mlflow*"
|
||||
- "opik"
|
||||
- "weave*"
|
||||
- "arize*"
|
||||
- "tiktoken"
|
||||
- "transformers"
|
||||
database:
|
||||
patterns:
|
||||
- "sqlalchemy"
|
||||
- "psycopg2*"
|
||||
- "psycogreen"
|
||||
- "redis*"
|
||||
- "alembic*"
|
||||
storage:
|
||||
patterns:
|
||||
- "boto3*"
|
||||
- "botocore*"
|
||||
- "azure-*"
|
||||
- "bce-*"
|
||||
- "cos-python-*"
|
||||
- "esdk-obs-*"
|
||||
- "google-cloud-storage"
|
||||
- "opendal"
|
||||
- "oss2"
|
||||
- "supabase*"
|
||||
- "tos*"
|
||||
vdb:
|
||||
patterns:
|
||||
- "alibabacloud*"
|
||||
- "chromadb"
|
||||
- "clickhouse-*"
|
||||
- "clickzetta-*"
|
||||
- "couchbase"
|
||||
- "elasticsearch"
|
||||
- "opensearch-py"
|
||||
- "oracledb"
|
||||
- "pgvect*"
|
||||
- "pymilvus"
|
||||
- "pymochow"
|
||||
- "pyobvector"
|
||||
- "qdrant-client"
|
||||
- "intersystems-*"
|
||||
- "tablestore"
|
||||
- "tcvectordb"
|
||||
- "tidb-vector"
|
||||
- "upstash-*"
|
||||
- "volcengine-*"
|
||||
- "weaviate-*"
|
||||
- "xinference-*"
|
||||
- "mo-vector"
|
||||
- "mysql-connector-*"
|
||||
dev:
|
||||
patterns:
|
||||
- "coverage"
|
||||
- "dotenv-linter"
|
||||
- "faker"
|
||||
- "lxml-stubs"
|
||||
- "basedpyright"
|
||||
- "ruff"
|
||||
- "pytest*"
|
||||
- "types-*"
|
||||
- "boto3-stubs"
|
||||
- "hypothesis"
|
||||
- "pandas-stubs"
|
||||
- "scipy-stubs"
|
||||
- "import-linter"
|
||||
- "celery-types"
|
||||
- "mypy*"
|
||||
- "pyrefly"
|
||||
python-packages:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "uv"
|
||||
directory: "/api"
|
||||
open-pull-requests-limit: 2
|
||||
open-pull-requests-limit: 10
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
uv-dependencies:
|
||||
flask:
|
||||
patterns:
|
||||
- "flask"
|
||||
- "flask-*"
|
||||
- "werkzeug"
|
||||
- "gunicorn"
|
||||
google:
|
||||
patterns:
|
||||
- "google-*"
|
||||
- "googleapis-*"
|
||||
opentelemetry:
|
||||
patterns:
|
||||
- "opentelemetry-*"
|
||||
pydantic:
|
||||
patterns:
|
||||
- "pydantic"
|
||||
- "pydantic-*"
|
||||
llm:
|
||||
patterns:
|
||||
- "langfuse"
|
||||
- "langsmith"
|
||||
- "litellm"
|
||||
- "mlflow*"
|
||||
- "opik"
|
||||
- "weave*"
|
||||
- "arize*"
|
||||
- "tiktoken"
|
||||
- "transformers"
|
||||
database:
|
||||
patterns:
|
||||
- "sqlalchemy"
|
||||
- "psycopg2*"
|
||||
- "psycogreen"
|
||||
- "redis*"
|
||||
- "alembic*"
|
||||
storage:
|
||||
patterns:
|
||||
- "boto3*"
|
||||
- "botocore*"
|
||||
- "azure-*"
|
||||
- "bce-*"
|
||||
- "cos-python-*"
|
||||
- "esdk-obs-*"
|
||||
- "google-cloud-storage"
|
||||
- "opendal"
|
||||
- "oss2"
|
||||
- "supabase*"
|
||||
- "tos*"
|
||||
vdb:
|
||||
patterns:
|
||||
- "alibabacloud*"
|
||||
- "chromadb"
|
||||
- "clickhouse-*"
|
||||
- "clickzetta-*"
|
||||
- "couchbase"
|
||||
- "elasticsearch"
|
||||
- "opensearch-py"
|
||||
- "oracledb"
|
||||
- "pgvect*"
|
||||
- "pymilvus"
|
||||
- "pymochow"
|
||||
- "pyobvector"
|
||||
- "qdrant-client"
|
||||
- "intersystems-*"
|
||||
- "tablestore"
|
||||
- "tcvectordb"
|
||||
- "tidb-vector"
|
||||
- "upstash-*"
|
||||
- "volcengine-*"
|
||||
- "weaviate-*"
|
||||
- "xinference-*"
|
||||
- "mo-vector"
|
||||
- "mysql-connector-*"
|
||||
dev:
|
||||
patterns:
|
||||
- "coverage"
|
||||
- "dotenv-linter"
|
||||
- "faker"
|
||||
- "lxml-stubs"
|
||||
- "basedpyright"
|
||||
- "ruff"
|
||||
- "pytest*"
|
||||
- "types-*"
|
||||
- "boto3-stubs"
|
||||
- "hypothesis"
|
||||
- "pandas-stubs"
|
||||
- "scipy-stubs"
|
||||
- "import-linter"
|
||||
- "celery-types"
|
||||
- "mypy*"
|
||||
- "pyrefly"
|
||||
python-packages:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "npm"
|
||||
directory: "/web"
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
open-pull-requests-limit: 5
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 2
|
||||
ignore:
|
||||
- dependency-name: "tailwind-merge"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "tailwindcss"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "react-syntax-highlighter"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "react-window"
|
||||
update-types: ["version-update:semver-major"]
|
||||
groups:
|
||||
lexical:
|
||||
patterns:
|
||||
- "lexical"
|
||||
- "@lexical/*"
|
||||
storybook:
|
||||
patterns:
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
eslint-group:
|
||||
patterns:
|
||||
- "*eslint*"
|
||||
npm-dependencies:
|
||||
github-actions-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
exclude-patterns:
|
||||
- "lexical"
|
||||
- "@lexical/*"
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
- "*eslint*"
|
||||
|
||||
2
.github/workflows/anti-slop.yml
vendored
2
.github/workflows/anti-slop.yml
vendored
@@ -15,3 +15,5 @@ jobs:
|
||||
- uses: peakoss/anti-slop@v0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
close-pr: false
|
||||
failure-add-pr-labels: "needs-revision"
|
||||
|
||||
2
.github/workflows/api-tests.yml
vendored
2
.github/workflows/api-tests.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
6
.github/workflows/autofix.yml
vendored
6
.github/workflows/autofix.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
- uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
@@ -102,13 +102,11 @@ jobs:
|
||||
- name: Setup web environment
|
||||
if: steps.web-changes.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
with:
|
||||
node-version: "24"
|
||||
|
||||
- name: ESLint autofix
|
||||
if: steps.web-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd web
|
||||
pnpm eslint --concurrency=2 --prune-suppressions --quiet || true
|
||||
vp exec eslint --concurrency=2 --prune-suppressions --quiet || true
|
||||
|
||||
- uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
|
||||
|
||||
4
.github/workflows/db-migration-test.yml
vendored
4
.github/workflows/db-migration-test.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -69,7 +69,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
3
.github/workflows/main-ci.yml
vendored
3
.github/workflows/main-ci.yml
vendored
@@ -62,6 +62,9 @@ jobs:
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.web-changed == 'true'
|
||||
uses: ./.github/workflows/web-tests.yml
|
||||
with:
|
||||
base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }}
|
||||
head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
|
||||
style-check:
|
||||
name: Style Check
|
||||
|
||||
2
.github/workflows/pyrefly-diff.yml
vendored
2
.github/workflows/pyrefly-diff.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
10
.github/workflows/style.yml
vendored
10
.github/workflows/style.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: |
|
||||
pnpm run lint:ci
|
||||
vp run lint:ci
|
||||
# pnpm run lint:report
|
||||
# continue-on-error: true
|
||||
|
||||
@@ -102,17 +102,17 @@ jobs:
|
||||
- name: Web tsslint
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run lint:tss
|
||||
run: vp run lint:tss
|
||||
|
||||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run type-check
|
||||
run: vp run type-check
|
||||
|
||||
- name: Web dead code check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run knip
|
||||
run: vp run knip
|
||||
|
||||
superlinter:
|
||||
name: SuperLinter
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@@ -50,8 +50,6 @@ jobs:
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
with:
|
||||
install-dependencies: "false"
|
||||
|
||||
- name: Detect changed files and generate diff
|
||||
id: detect_changes
|
||||
|
||||
2
.github/workflows/vdb-tests.yml
vendored
2
.github/workflows/vdb-tests.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
24
.github/workflows/web-tests.yml
vendored
24
.github/workflows/web-tests.yml
vendored
@@ -2,6 +2,13 @@ name: Web Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
base_sha:
|
||||
required: false
|
||||
type: string
|
||||
head_sha:
|
||||
required: false
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -14,6 +21,8 @@ jobs:
|
||||
test:
|
||||
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
VITEST_COVERAGE_SCOPE: app-components
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -34,7 +43,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
run: vp test run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
|
||||
- name: Upload blob report
|
||||
if: ${{ !cancelled() }}
|
||||
@@ -50,6 +59,8 @@ jobs:
|
||||
if: ${{ !cancelled() }}
|
||||
needs: [test]
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
VITEST_COVERAGE_SCOPE: app-components
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -59,6 +70,7 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
@@ -72,7 +84,13 @@ jobs:
|
||||
merge-multiple: true
|
||||
|
||||
- name: Merge reports
|
||||
run: pnpm vitest --merge-reports --coverage --silent=passed-only
|
||||
run: vp test --merge-reports --reporter=json --reporter=agent --coverage
|
||||
|
||||
- name: Check app/components diff coverage
|
||||
env:
|
||||
BASE_SHA: ${{ inputs.base_sha }}
|
||||
HEAD_SHA: ${{ inputs.head_sha }}
|
||||
run: node ./scripts/check-components-diff-coverage.mjs
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
@@ -429,4 +447,4 @@ jobs:
|
||||
- name: Web build check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run build
|
||||
run: vp run build
|
||||
|
||||
@@ -188,7 +188,6 @@ VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
# Weaviate configuration
|
||||
WEAVIATE_ENDPOINT=http://localhost:8080
|
||||
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
WEAVIATE_TOKENIZATION=word
|
||||
|
||||
|
||||
@@ -17,11 +17,6 @@ class WeaviateConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
WEAVIATE_GRPC_ENABLED: bool = Field(
|
||||
description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)",
|
||||
default=True,
|
||||
)
|
||||
|
||||
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
|
||||
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
|
||||
default=None,
|
||||
|
||||
@@ -39,6 +39,7 @@ from . import (
|
||||
feature,
|
||||
human_input_form,
|
||||
init_validate,
|
||||
notification,
|
||||
ping,
|
||||
setup,
|
||||
spec,
|
||||
@@ -184,6 +185,7 @@ __all__ = [
|
||||
"model_config",
|
||||
"model_providers",
|
||||
"models",
|
||||
"notification",
|
||||
"oauth",
|
||||
"oauth_server",
|
||||
"ops_trace",
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import csv
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
@@ -6,7 +8,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@@ -16,6 +18,7 @@ from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@@ -277,3 +280,168 @@ class DeleteExploreBannerApi(Resource):
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class LangContentPayload(BaseModel):
|
||||
lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
|
||||
title: str = Field(...)
|
||||
subtitle: str | None = Field(default=None)
|
||||
body: str = Field(...)
|
||||
title_pic_url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class UpsertNotificationPayload(BaseModel):
|
||||
notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
|
||||
contents: list[LangContentPayload] = Field(..., min_length=1)
|
||||
start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
|
||||
end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
|
||||
frequency: str = Field(default="once", description="'once' | 'every_page_load'")
|
||||
status: str = Field(default="active", description="'active' | 'inactive'")
|
||||
|
||||
|
||||
class BatchAddNotificationAccountsPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
user_email: list[str] = Field(..., description="List of account email addresses")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
UpsertNotificationPayload.__name__,
|
||||
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
BatchAddNotificationAccountsPayload.__name__,
|
||||
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/admin/upsert_notification")
|
||||
class UpsertNotificationApi(Resource):
|
||||
@console_ns.doc("upsert_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Create or update an in-product notification. "
|
||||
"Supply notification_id to update an existing one; omit it to create a new one. "
|
||||
"Pass at least one language variant in contents (zh / en / jp)."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
|
||||
@console_ns.response(200, "Notification upserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
||||
result = BillingService.upsert_notification(
|
||||
contents=[c.model_dump() for c in payload.contents],
|
||||
frequency=payload.frequency,
|
||||
status=payload.status,
|
||||
notification_id=payload.notification_id,
|
||||
start_time=payload.start_time,
|
||||
end_time=payload.end_time,
|
||||
)
|
||||
return {"result": "success", "notification_id": result.get("notificationId")}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/batch_add_notification_accounts")
|
||||
class BatchAddNotificationAccountsApi(Resource):
|
||||
@console_ns.doc("batch_add_notification_accounts")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Register target accounts for a notification by email address. "
|
||||
'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
|
||||
"File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
|
||||
"plus a 'notification_id' field. "
|
||||
"Emails that do not match any account are silently skipped."
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Accounts added successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
from models.account import Account
|
||||
|
||||
if "file" in request.files:
|
||||
notification_id = request.form.get("notification_id", "").strip()
|
||||
if not notification_id:
|
||||
raise BadRequest("notification_id is required.")
|
||||
emails = self._parse_emails_from_file()
|
||||
else:
|
||||
payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
|
||||
notification_id = payload.notification_id
|
||||
emails = payload.user_email
|
||||
|
||||
if not emails:
|
||||
raise BadRequest("No valid email addresses provided.")
|
||||
|
||||
# Resolve emails → account IDs in chunks to avoid large IN-clause
|
||||
account_ids: list[str] = []
|
||||
chunk_size = 500
|
||||
for i in range(0, len(emails), chunk_size):
|
||||
chunk = emails[i : i + chunk_size]
|
||||
rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(chunk))).all()
|
||||
account_ids.extend(str(row.id) for row in rows)
|
||||
|
||||
if not account_ids:
|
||||
raise BadRequest("None of the provided emails matched an existing account.")
|
||||
|
||||
# Send to dify-saas in batches of 1000
|
||||
total_count = 0
|
||||
batch_size = 1000
|
||||
for i in range(0, len(account_ids), batch_size):
|
||||
batch = account_ids[i : i + batch_size]
|
||||
result = BillingService.batch_add_notification_accounts(
|
||||
notification_id=notification_id,
|
||||
account_ids=batch,
|
||||
)
|
||||
total_count += result.get("count", 0)
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"emails_provided": len(emails),
|
||||
"accounts_matched": len(account_ids),
|
||||
"count": total_count,
|
||||
}, 200
|
||||
|
||||
@staticmethod
|
||||
def _parse_emails_from_file() -> list[str]:
|
||||
"""Parse email addresses from an uploaded CSV or TXT file."""
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
raise BadRequest("Uploaded file has no filename.")
|
||||
|
||||
filename_lower = file.filename.lower()
|
||||
if not filename_lower.endswith((".csv", ".txt")):
|
||||
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
|
||||
|
||||
try:
|
||||
content = file.read().decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
file.seek(0)
|
||||
content = file.read().decode("gbk")
|
||||
except UnicodeDecodeError:
|
||||
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
|
||||
|
||||
emails: list[str] = []
|
||||
if filename_lower.endswith(".csv"):
|
||||
reader = csv.reader(io.StringIO(content))
|
||||
for row in reader:
|
||||
for cell in row:
|
||||
cell = cell.strip()
|
||||
if cell:
|
||||
emails.append(cell)
|
||||
else:
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
emails.append(line)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen: set[str] = set()
|
||||
unique_emails: list[str] = []
|
||||
for email in emails:
|
||||
if email.lower() not in seen:
|
||||
seen.add(email.lower())
|
||||
unique_emails.append(email)
|
||||
|
||||
return unique_emails
|
||||
|
||||
90
api/controllers/console/notification.py
Normal file
90
api/controllers/console/notification.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
# Notification content is stored under three lang tags.
|
||||
_FALLBACK_LANG = "en-US"
|
||||
|
||||
|
||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||
"""Return the single LangContent for *lang*, falling back to English."""
|
||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||
|
||||
|
||||
class DismissNotificationPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
|
||||
|
||||
@console_ns.route("/notification")
|
||||
class NotificationApi(Resource):
|
||||
@console_ns.doc("get_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Return the active in-product notification for the current user "
|
||||
"in their interface language (falls back to English if unavailable). "
|
||||
"The notification is NOT marked as seen here; call POST /notification/dismiss "
|
||||
"when the user explicitly closes the modal."
|
||||
),
|
||||
responses={
|
||||
200: "Success — inspect should_show to decide whether to render the modal",
|
||||
401: "Unauthorized",
|
||||
},
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
result = BillingService.get_account_notification(str(current_user.id))
|
||||
|
||||
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||
if not result.get("shouldShow"):
|
||||
return {"should_show": False, "notifications": []}, 200
|
||||
|
||||
lang = current_user.interface_language or _FALLBACK_LANG
|
||||
|
||||
notifications = []
|
||||
for notification in result.get("notifications") or []:
|
||||
contents: dict = notification.get("contents") or {}
|
||||
lang_content = _pick_lang_content(contents, lang)
|
||||
notifications.append(
|
||||
{
|
||||
"notification_id": notification.get("notificationId"),
|
||||
"frequency": notification.get("frequency"),
|
||||
"lang": lang_content.get("lang", lang),
|
||||
"title": lang_content.get("title", ""),
|
||||
"subtitle": lang_content.get("subtitle", ""),
|
||||
"body": lang_content.get("body", ""),
|
||||
"title_pic_url": lang_content.get("titlePicUrl", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return {"should_show": bool(notifications), "notifications": notifications}, 200
|
||||
|
||||
|
||||
@console_ns.route("/notification/dismiss")
|
||||
class NotificationDismissApi(Resource):
|
||||
@console_ns.doc("dismiss_notification")
|
||||
@console_ns.doc(
|
||||
description="Mark a notification as dismissed for the current user.",
|
||||
responses={200: "Success", 401: "Unauthorized"},
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = DismissNotificationPayload.model_validate(request.get_json())
|
||||
BillingService.dismiss_notification(
|
||||
notification_id=payload.notification_id,
|
||||
account_id=str(current_user.id),
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
@@ -43,6 +43,7 @@ from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
@@ -231,7 +232,7 @@ class AccountInitApi(Resource):
|
||||
account.interface_language = args.interface_language
|
||||
account.timezone = args.timezone
|
||||
account.interface_theme = "light"
|
||||
account.status = "active"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@@ -114,6 +114,7 @@ def get_user_tenant(view_func: Callable[P, R]):
|
||||
|
||||
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
try:
|
||||
data = request.get_json()
|
||||
|
||||
@@ -138,20 +138,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
query = self.application_generate_entity.query
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
stop, new_inputs, new_query = self.handle_input_moderation(
|
||||
app_record=self._app,
|
||||
app_generate_entity=self.application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=self.message.id,
|
||||
):
|
||||
)
|
||||
if stop:
|
||||
return
|
||||
|
||||
self.application_generate_entity.inputs = new_inputs
|
||||
self.application_generate_entity.query = new_query
|
||||
system_inputs.query = new_query
|
||||
|
||||
# annotation reply
|
||||
if self.handle_annotation_reply(
|
||||
app_record=self._app,
|
||||
message=self.message,
|
||||
query=query,
|
||||
query=new_query,
|
||||
app_generate_entity=self.application_generate_entity,
|
||||
):
|
||||
return
|
||||
@@ -163,7 +168,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
user_inputs=new_inputs,
|
||||
environment_variables=self._workflow.environment_variables,
|
||||
# Based on the definition of `Variable`,
|
||||
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
|
||||
@@ -240,10 +245,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str,
|
||||
) -> bool:
|
||||
) -> tuple[bool, Mapping[str, Any], str]:
|
||||
try:
|
||||
# process sensitive_word_avoidance
|
||||
_, inputs, query = self.moderation_for_inputs(
|
||||
_, new_inputs, new_query = self.moderation_for_inputs(
|
||||
app_id=app_record.id,
|
||||
tenant_id=app_generate_entity.app_config.tenant_id,
|
||||
app_generate_entity=app_generate_entity,
|
||||
@@ -253,9 +258,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
except ModerationError as e:
|
||||
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
|
||||
return True
|
||||
return True, inputs, query
|
||||
|
||||
return False
|
||||
return False, new_inputs, new_query
|
||||
|
||||
def handle_annotation_reply(
|
||||
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
|
||||
|
||||
@@ -114,7 +114,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
|
||||
@@ -113,7 +113,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
|
||||
@@ -113,7 +113,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
|
||||
@@ -30,8 +30,10 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
@@ -62,8 +64,6 @@ from dify_graph.graph_events import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from dify_graph.graph_events.graph import GraphRunAbortedEvent
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
@@ -303,9 +303,11 @@ class WorkflowBasedAppRunner:
|
||||
if not target_node_config:
|
||||
raise ValueError(f"{node_type_label} node id not found in workflow graph")
|
||||
|
||||
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(target_node_config.get("data", {}).get("type"))
|
||||
node_version = target_node_config.get("data", {}).get("version", "1")
|
||||
node_type = target_node_config["data"].type
|
||||
node_version = str(target_node_config["data"].version)
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# Use the variable pool from graph_runtime_state instead of creating a new one
|
||||
|
||||
@@ -316,7 +316,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
start_at: datetime
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode, need to refactor
|
||||
# Legacy provider fields kept for existing start-event consumers.
|
||||
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
|
||||
provider_id: str
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,7 +39,9 @@ class DatasetIndexToolCallbackHandler:
|
||||
source="app",
|
||||
source_app_id=self._app_id,
|
||||
created_by_role=(
|
||||
"account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
|
||||
CreatorUserRole.ACCOUNT
|
||||
if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else CreatorUserRole.END_USER
|
||||
),
|
||||
created_by=self._user_id,
|
||||
)
|
||||
|
||||
@@ -59,8 +59,6 @@ class DatasourcePluginProviderController(ABC):
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
|
||||
@@ -193,7 +193,8 @@ class LLMGenerator:
|
||||
error_step = "generate rule config"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
rule_config["error"] = str(e)
|
||||
error = str(e)
|
||||
error_step = "generate rule config"
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
@@ -279,7 +280,8 @@ class LLMGenerator:
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
rule_config["error"] = str(e)
|
||||
error = str(e)
|
||||
error_step = "handle unexpected exception"
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
|
||||
@@ -628,10 +628,10 @@ class TraceTask:
|
||||
if not message_data:
|
||||
return {}
|
||||
conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
|
||||
conversation_mode = db.session.scalars(conversation_mode_stmt).all()
|
||||
if not conversation_mode or len(conversation_mode) == 0:
|
||||
conversation_modes = db.session.scalars(conversation_mode_stmt).all()
|
||||
if not conversation_modes or len(conversation_modes) == 0:
|
||||
return {}
|
||||
conversation_mode = conversation_mode[0]
|
||||
conversation_mode = conversation_modes[0]
|
||||
created_at = message_data.created_at
|
||||
inputs = message_data.message
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
|
||||
raise ValueError(f"The tool parameter value {repr(value)} is not in correct type of {as_normal_type(typ)}.")
|
||||
|
||||
|
||||
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
|
||||
|
||||
@@ -157,6 +157,7 @@ class PluginInstallTaskPluginStatus(BaseModel):
|
||||
message: str = Field(description="The message of the install task.")
|
||||
icon: str = Field(description="The icon of the plugin.")
|
||||
labels: I18nObject = Field(description="The labels of the plugin.")
|
||||
source: str | None = Field(default=None, description="The installation source of the plugin")
|
||||
|
||||
|
||||
class PluginInstallTask(BasePluginEntity):
|
||||
|
||||
@@ -627,7 +627,7 @@ class ProviderManager:
|
||||
tenant_id=tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
provider_name=ModelProviderID(provider_name).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
provider_type=ProviderType.SYSTEM,
|
||||
quota_type=quota.quota_type,
|
||||
quota_limit=0, # type: ignore
|
||||
quota_used=0,
|
||||
|
||||
@@ -74,7 +74,8 @@ class ExtractProcessor:
|
||||
else:
|
||||
suffix = ""
|
||||
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
|
||||
file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}"
|
||||
# Generate a temporary filename under the created temp_dir and ensure the directory exists
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
Path(file_path).write_bytes(response.content)
|
||||
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model")
|
||||
if return_text:
|
||||
|
||||
@@ -204,26 +204,61 @@ class WordExtractor(BaseExtractor):
|
||||
return " ".join(unique_content)
|
||||
|
||||
def _parse_cell_paragraph(self, paragraph, image_map):
|
||||
paragraph_content = []
|
||||
for run in paragraph.runs:
|
||||
if run.element.xpath(".//a:blip"):
|
||||
for blip in run.element.xpath(".//a:blip"):
|
||||
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
|
||||
if not image_id:
|
||||
continue
|
||||
rel = paragraph.part.rels.get(image_id)
|
||||
if rel is None:
|
||||
continue
|
||||
# For external images, use image_id as key; for internal, use target_part
|
||||
if rel.is_external:
|
||||
if image_id in image_map:
|
||||
paragraph_content.append(image_map[image_id])
|
||||
else:
|
||||
image_part = rel.target_part
|
||||
if image_part in image_map:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
else:
|
||||
paragraph_content.append(run.text)
|
||||
paragraph_content: list[str] = []
|
||||
|
||||
for child in paragraph._element:
|
||||
tag = child.tag
|
||||
if tag == qn("w:hyperlink"):
|
||||
# Note: w:hyperlink elements may also use w:anchor for internal bookmarks.
|
||||
# This extractor intentionally only converts external links (HTTP/mailto, etc.)
|
||||
# that are backed by a relationship id (r:id) with rel.is_external == True.
|
||||
# Hyperlinks without such an external rel (including anchor-only bookmarks)
|
||||
# are left as plain text link_text.
|
||||
r_id = child.get(qn("r:id"))
|
||||
link_text_parts: list[str] = []
|
||||
for run_elem in child.findall(qn("w:r")):
|
||||
run = Run(run_elem, paragraph)
|
||||
if run.text:
|
||||
link_text_parts.append(run.text)
|
||||
link_text = "".join(link_text_parts).strip()
|
||||
if r_id:
|
||||
try:
|
||||
rel = paragraph.part.rels.get(r_id)
|
||||
if rel:
|
||||
target_ref = getattr(rel, "target_ref", None)
|
||||
if target_ref:
|
||||
parsed_target = urlparse(str(target_ref))
|
||||
if rel.is_external or parsed_target.scheme in ("http", "https", "mailto"):
|
||||
display_text = link_text or str(target_ref)
|
||||
link_text = f"[{display_text}]({target_ref})"
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id)
|
||||
if link_text:
|
||||
paragraph_content.append(link_text)
|
||||
|
||||
elif tag == qn("w:r"):
|
||||
run = Run(child, paragraph)
|
||||
if run.element.xpath(".//a:blip"):
|
||||
for blip in run.element.xpath(".//a:blip"):
|
||||
image_id = blip.get(
|
||||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
|
||||
)
|
||||
if not image_id:
|
||||
continue
|
||||
rel = paragraph.part.rels.get(image_id)
|
||||
if rel is None:
|
||||
continue
|
||||
if rel.is_external:
|
||||
if image_id in image_map:
|
||||
paragraph_content.append(image_map[image_id])
|
||||
else:
|
||||
image_part = rel.target_part
|
||||
if image_part in image_map:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
else:
|
||||
if run.text:
|
||||
paragraph_content.append(run.text)
|
||||
|
||||
return "".join(paragraph_content).strip()
|
||||
|
||||
def parse_docx(self, docx_path):
|
||||
|
||||
@@ -83,6 +83,7 @@ from models.dataset import (
|
||||
)
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.dataset import Document as DocumentModel
|
||||
from models.enums import CreatorUserRole
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@@ -1009,7 +1010,7 @@ class DatasetRetrieval:
|
||||
content=json.dumps(contents),
|
||||
source="app",
|
||||
source_app_id=app_id,
|
||||
created_by_role=user_from,
|
||||
created_by_role=CreatorUserRole(user_from),
|
||||
created_by=user_id,
|
||||
)
|
||||
dataset_queries.append(dataset_query)
|
||||
|
||||
@@ -146,7 +146,9 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
|
||||
# No sequence number generation needed anymore
|
||||
|
||||
db_model.type = domain_model.workflow_type
|
||||
from models.workflow import WorkflowType as ModelWorkflowType
|
||||
|
||||
db_model.type = ModelWorkflowType(domain_model.workflow_type.value)
|
||||
db_model.version = domain_model.workflow_version
|
||||
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
|
||||
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
|
||||
|
||||
@@ -113,17 +113,26 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self.get_credentials_schema_by_type(CredentialType.API_KEY)
|
||||
|
||||
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
|
||||
def get_credentials_schema_by_type(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:param credential_type: the type of the credential
|
||||
:return: the credentials schema of the provider
|
||||
:param credential_type: the type of the credential, as CredentialType or str; str values
|
||||
are normalized via CredentialType.of and may raise ValueError for invalid values.
|
||||
:return: list[ProviderConfig] for CredentialType.OAUTH2 or CredentialType.API_KEY, an
|
||||
empty list for CredentialType.UNAUTHORIZED or missing schemas.
|
||||
|
||||
Reads from self.entity.oauth_schema and self.entity.credentials_schema.
|
||||
Raises ValueError for invalid credential types.
|
||||
"""
|
||||
if credential_type == CredentialType.OAUTH2.value:
|
||||
if isinstance(credential_type, str):
|
||||
credential_type = CredentialType.of(credential_type)
|
||||
if credential_type == CredentialType.OAUTH2:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
if credential_type == CredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
return []
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||
|
||||
@@ -137,6 +137,7 @@ class ToolFileManager:
|
||||
|
||||
session.add(tool_file)
|
||||
session.commit()
|
||||
session.refresh(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
|
||||
@@ -19,9 +19,10 @@ from core.trigger.debug.events import (
|
||||
build_plugin_pool_key,
|
||||
build_webhook_pool_key,
|
||||
)
|
||||
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
|
||||
from libs.schedule_utils import calculate_next_run_at
|
||||
@@ -41,10 +42,10 @@ class TriggerDebugEventPoller(ABC):
|
||||
app_id: str
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
node_config: Mapping[str, Any]
|
||||
node_config: NodeConfigDict
|
||||
node_id: str
|
||||
|
||||
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: Mapping[str, Any], node_id: str):
|
||||
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: NodeConfigDict, node_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self.app_id = app_id
|
||||
@@ -60,7 +61,7 @@ class PluginTriggerDebugEventPoller(TriggerDebugEventPoller):
|
||||
def poll(self) -> TriggerDebugEvent | None:
|
||||
from services.trigger.trigger_service import TriggerService
|
||||
|
||||
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config.get("data", {}))
|
||||
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config["data"], from_attributes=True)
|
||||
provider_id = TriggerProviderID(plugin_trigger_data.provider_id)
|
||||
pool_key: str = build_plugin_pool_key(
|
||||
name=plugin_trigger_data.event_name,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, cast, final
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -22,7 +22,9 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import NodeType, SystemVariableKey
|
||||
from dify_graph.file.file_manager import file_manager
|
||||
@@ -31,26 +33,18 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.memory import PromptMessageMemory
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
|
||||
from dify_graph.nodes.code.code_node import WorkflowCodeExecutor
|
||||
from dify_graph.nodes.code.entities import CodeLanguage
|
||||
from dify_graph.nodes.code.limits import CodeNodeLimits
|
||||
from dify_graph.nodes.datasource import DatasourceNode
|
||||
from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||
from dify_graph.nodes.human_input.human_input_node import HumanInputNode
|
||||
from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
|
||||
from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from dify_graph.nodes.llm.entities import ModelConfig
|
||||
from dify_graph.nodes.document_extractor import UnstructuredApiConfig
|
||||
from dify_graph.nodes.http_request import build_http_request_config
|
||||
from dify_graph.nodes.llm.entities import LLMNodeData
|
||||
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from dify_graph.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
)
|
||||
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
@@ -60,6 +54,9 @@ if TYPE_CHECKING:
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
|
||||
|
||||
|
||||
def fetch_memory(
|
||||
*,
|
||||
conversation_id: str | None,
|
||||
@@ -157,178 +154,128 @@ class DifyNodeFactory(NodeFactory):
|
||||
return DifyRunContext.model_validate(raw_ctx)
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data using the traditional mapping.
|
||||
|
||||
:param node_config: node configuration dictionary containing type and other data
|
||||
:return: initialized Node instance
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
:raises ValueError: if node_config fails NodeConfigDict/BaseNodeData validation
|
||||
(including pydantic ValidationError, which subclasses ValueError),
|
||||
if node type is unknown, or if no implementation exists for the resolved version
|
||||
"""
|
||||
# Get node_id from config
|
||||
node_id = node_config["id"]
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
|
||||
node_id = typed_node_config["id"]
|
||||
node_data = typed_node_config["data"]
|
||||
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
|
||||
node_type = node_data.type
|
||||
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
|
||||
NodeType.CODE: lambda: {
|
||||
"code_executor": self._code_executor,
|
||||
"code_limits": self._code_limits,
|
||||
},
|
||||
NodeType.TEMPLATE_TRANSFORM: lambda: {
|
||||
"template_renderer": self._template_renderer,
|
||||
"max_output_length": self._template_transform_max_output_length,
|
||||
},
|
||||
NodeType.HTTP_REQUEST: lambda: {
|
||||
"http_request_config": self._http_request_config,
|
||||
"http_client": self._http_request_http_client,
|
||||
"tool_file_manager_factory": self._http_request_tool_file_manager_factory,
|
||||
"file_manager": self._http_request_file_manager,
|
||||
},
|
||||
NodeType.HUMAN_INPUT: lambda: {
|
||||
"form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
|
||||
},
|
||||
NodeType.KNOWLEDGE_INDEX: lambda: {
|
||||
"index_processor": IndexProcessor(),
|
||||
"summary_index_service": SummaryIndex(),
|
||||
},
|
||||
NodeType.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
include_http_client=True,
|
||||
),
|
||||
NodeType.DATASOURCE: lambda: {
|
||||
"datasource_manager": DatasourceManager,
|
||||
},
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: lambda: {
|
||||
"rag_retrieval": self._rag_retrieval,
|
||||
},
|
||||
NodeType.DOCUMENT_EXTRACTOR: lambda: {
|
||||
"unstructured_api_config": self._document_extractor_unstructured_api_config,
|
||||
"http_client": self._http_request_http_client,
|
||||
},
|
||||
NodeType.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
include_http_client=True,
|
||||
),
|
||||
NodeType.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
include_http_client=False,
|
||||
),
|
||||
NodeType.TOOL: lambda: {
|
||||
"tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
|
||||
},
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
**node_init_kwargs,
|
||||
)
|
||||
|
||||
# Get node type from config
|
||||
node_data = node_config["data"]
|
||||
try:
|
||||
node_type = NodeType(node_data["type"])
|
||||
except ValueError:
|
||||
raise ValueError(f"Unknown node type: {node_data['type']}")
|
||||
@staticmethod
|
||||
def _validate_resolved_node_data(node_class: type[Node], node_data: BaseNodeData) -> BaseNodeData:
|
||||
"""
|
||||
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
|
||||
"""
|
||||
return node_class.validate_node_data(node_data)
|
||||
|
||||
# Get node class
|
||||
@staticmethod
|
||||
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
latest_node_class = node_mapping.get(LATEST_VERSION)
|
||||
node_version = str(node_data.get("version", "1"))
|
||||
matched_node_class = node_mapping.get(node_version)
|
||||
node_class = matched_node_class or latest_node_class
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
return node_class
|
||||
|
||||
# Create node instance
|
||||
if node_type == NodeType.CODE:
|
||||
return CodeNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
code_executor=self._code_executor,
|
||||
code_limits=self._code_limits,
|
||||
)
|
||||
|
||||
if node_type == NodeType.TEMPLATE_TRANSFORM:
|
||||
return TemplateTransformNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
template_renderer=self._template_renderer,
|
||||
max_output_length=self._template_transform_max_output_length,
|
||||
)
|
||||
|
||||
if node_type == NodeType.HTTP_REQUEST:
|
||||
return HttpRequestNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
http_request_config=self._http_request_config,
|
||||
http_client=self._http_request_http_client,
|
||||
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
|
||||
file_manager=self._http_request_file_manager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.HUMAN_INPUT:
|
||||
return HumanInputNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_INDEX:
|
||||
return KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
index_processor=IndexProcessor(),
|
||||
summary_index_service=SummaryIndex(),
|
||||
)
|
||||
|
||||
if node_type == NodeType.LLM:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return LLMNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DATASOURCE:
|
||||
return DatasourceNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
datasource_manager=DatasourceManager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
return KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
rag_retrieval=self._rag_retrieval,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DOCUMENT_EXTRACTOR:
|
||||
return DocumentExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
unstructured_api_config=self._document_extractor_unstructured_api_config,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.QUESTION_CLASSIFIER:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return QuestionClassifierNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return ParameterExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.TOOL:
|
||||
return ToolNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
tool_file_manager_factory=self._http_request_tool_file_manager_factory(),
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
def _build_llm_compatible_node_init_kwargs(
|
||||
self,
|
||||
*,
|
||||
node_class: type[Node],
|
||||
node_data: BaseNodeData,
|
||||
include_http_client: bool,
|
||||
) -> dict[str, object]:
|
||||
validated_node_data = cast(
|
||||
LLMCompatibleNodeData,
|
||||
self._validate_resolved_node_data(node_class=node_class, node_data=node_data),
|
||||
)
|
||||
model_instance = self._build_model_instance_for_llm_node(validated_node_data)
|
||||
node_init_kwargs: dict[str, object] = {
|
||||
"credentials_provider": self._llm_credentials_provider,
|
||||
"model_factory": self._llm_model_factory,
|
||||
"model_instance": model_instance,
|
||||
"memory": self._build_memory_for_llm_node(
|
||||
node_data=validated_node_data,
|
||||
model_instance=model_instance,
|
||||
),
|
||||
}
|
||||
if include_http_client:
|
||||
node_init_kwargs["http_client"] = self._http_request_http_client
|
||||
return node_init_kwargs
|
||||
|
||||
def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance:
|
||||
node_data_model = ModelConfig.model_validate(node_data["model"])
|
||||
def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance:
|
||||
node_data_model = node_data.model
|
||||
if not node_data_model.mode:
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
@@ -364,14 +311,12 @@ class DifyNodeFactory(NodeFactory):
|
||||
def _build_memory_for_llm_node(
|
||||
self,
|
||||
*,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: LLMCompatibleNodeData,
|
||||
model_instance: ModelInstance,
|
||||
) -> PromptMessageMemory | None:
|
||||
raw_memory_config = node_data.get("memory")
|
||||
if raw_memory_config is None:
|
||||
if node_data.memory is None:
|
||||
return None
|
||||
|
||||
node_memory = MemoryConfig.model_validate(raw_memory_config)
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
@@ -381,6 +326,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
return fetch_memory(
|
||||
conversation_id=conversation_id,
|
||||
app_id=self._dify_context.app_id,
|
||||
node_data_memory=node_memory,
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
1
api/core/workflow/nodes/__init__.py
Normal file
1
api/core/workflow/nodes/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Core-owned workflow node packages."""
|
||||
30
api/core/workflow/nodes/node_mapping.py
Normal file
30
api/core/workflow/nodes/node_mapping.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Node mapping for workflow execution.
|
||||
|
||||
`core.workflow` owns the trigger node implementations, while the remaining node
|
||||
implementations still live under `dify_graph`. This module imports the
|
||||
core-owned node packages first, then asks the shared `Node` registry to load the
|
||||
rest of the workflow nodes from `dify_graph`.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import pkgutil
|
||||
from collections.abc import Mapping
|
||||
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.node import Node
|
||||
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
|
||||
def _register_core_workflow_nodes() -> None:
|
||||
import core.workflow.nodes as workflow_nodes_pkg
|
||||
|
||||
for _, modname, _ in pkgutil.walk_packages(workflow_nodes_pkg.__path__, workflow_nodes_pkg.__name__ + "."):
|
||||
if modname == "core.workflow.nodes.node_mapping":
|
||||
continue
|
||||
importlib.import_module(modname)
|
||||
|
||||
|
||||
_register_core_workflow_nodes()
|
||||
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||
@@ -4,13 +4,17 @@ from typing import Any, Literal, Union
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.trigger.entities.entities import EventParameter
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.nodes.trigger_plugin.exc import TriggerEventParameterError
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
from .exc import TriggerEventParameterError
|
||||
|
||||
|
||||
class TriggerEventNodeData(BaseNodeData):
|
||||
"""Plugin trigger node data"""
|
||||
|
||||
type: NodeType = NodeType.TRIGGER_PLUGIN
|
||||
|
||||
class TriggerEventInput(BaseModel):
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
@@ -38,8 +42,6 @@ class TriggerEventNodeData(BaseNodeData):
|
||||
raise ValueError("value must be a string, int, float, bool or dict")
|
||||
return type
|
||||
|
||||
title: str
|
||||
desc: str | None = None
|
||||
plugin_id: str = Field(..., description="Plugin ID")
|
||||
provider_id: str = Field(..., description="Provider ID")
|
||||
event_name: str = Field(..., description="Event name")
|
||||
@@ -3,6 +3,7 @@ from collections.abc import Mapping
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, NodeType
|
||||
from dify_graph.graph_events import NodeRunStartedEvent
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
|
||||
@@ -32,6 +33,11 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def customize_start_event(self, event: NodeRunStartedEvent) -> None:
|
||||
provider_id = self.node_data.provider_id
|
||||
event.provider_id = provider_id
|
||||
event.extras["provider_id"] = provider_id
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the plugin trigger node.
|
||||
3
api/core/workflow/nodes/trigger_schedule/__init__.py
Normal file
3
api/core/workflow/nodes/trigger_schedule/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .trigger_schedule_node import TriggerScheduleNode
|
||||
|
||||
__all__ = ["TriggerScheduleNode"]
|
||||
@@ -2,7 +2,8 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class TriggerScheduleNodeData(BaseNodeData):
|
||||
@@ -10,6 +11,7 @@ class TriggerScheduleNodeData(BaseNodeData):
|
||||
Trigger Schedule Node Data
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.TRIGGER_SCHEDULE
|
||||
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
|
||||
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
|
||||
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
|
||||
@@ -1,4 +1,4 @@
|
||||
from dify_graph.nodes.base.exc import BaseNodeError
|
||||
from dify_graph.entities.exc import BaseNodeError
|
||||
|
||||
|
||||
class ScheduleNodeError(BaseNodeError):
|
||||
@@ -5,7 +5,8 @@ from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionSta
|
||||
from dify_graph.enums import NodeExecutionType, NodeType
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.trigger_schedule.entities import TriggerScheduleNodeData
|
||||
|
||||
from .entities import TriggerScheduleNodeData
|
||||
|
||||
|
||||
class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
|
||||
132
api/core/workflow/nodes/trigger_webhook/entities.py
Normal file
132
api/core/workflow/nodes/trigger_webhook/entities.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
_WEBHOOK_HEADER_ALLOWED_TYPES = frozenset(
|
||||
{
|
||||
SegmentType.STRING,
|
||||
}
|
||||
)
|
||||
|
||||
_WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES = frozenset(
|
||||
{
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.BOOLEAN,
|
||||
}
|
||||
)
|
||||
|
||||
_WEBHOOK_PARAMETER_ALLOWED_TYPES = _WEBHOOK_HEADER_ALLOWED_TYPES | _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES
|
||||
|
||||
_WEBHOOK_BODY_ALLOWED_TYPES = frozenset(
|
||||
{
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.FILE,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class Method(StrEnum):
|
||||
GET = "get"
|
||||
POST = "post"
|
||||
HEAD = "head"
|
||||
PATCH = "patch"
|
||||
PUT = "put"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
class ContentType(StrEnum):
|
||||
JSON = "application/json"
|
||||
FORM_DATA = "multipart/form-data"
|
||||
FORM_URLENCODED = "application/x-www-form-urlencoded"
|
||||
TEXT = "text/plain"
|
||||
BINARY = "application/octet-stream"
|
||||
|
||||
|
||||
class WebhookParameter(BaseModel):
|
||||
"""Parameter definition for headers or query params."""
|
||||
|
||||
name: str
|
||||
type: SegmentType = SegmentType.STRING
|
||||
required: bool = False
|
||||
|
||||
@field_validator("type", mode="after")
|
||||
@classmethod
|
||||
def validate_type(cls, v: SegmentType) -> SegmentType:
|
||||
if v not in _WEBHOOK_PARAMETER_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook parameter type: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class WebhookBodyParameter(BaseModel):
|
||||
"""Body parameter with type information."""
|
||||
|
||||
name: str
|
||||
type: SegmentType = SegmentType.STRING
|
||||
required: bool = False
|
||||
|
||||
@field_validator("type", mode="after")
|
||||
@classmethod
|
||||
def validate_type(cls, v: SegmentType) -> SegmentType:
|
||||
if v not in _WEBHOOK_BODY_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook body parameter type: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class WebhookData(BaseNodeData):
|
||||
"""
|
||||
Webhook Node Data.
|
||||
"""
|
||||
|
||||
class SyncMode(StrEnum):
|
||||
SYNC = "async" # only support
|
||||
|
||||
type: NodeType = NodeType.TRIGGER_WEBHOOK
|
||||
method: Method = Method.GET
|
||||
content_type: ContentType = Field(default=ContentType.JSON)
|
||||
headers: Sequence[WebhookParameter] = Field(default_factory=list)
|
||||
params: Sequence[WebhookParameter] = Field(default_factory=list) # query parameters
|
||||
body: Sequence[WebhookBodyParameter] = Field(default_factory=list)
|
||||
|
||||
@field_validator("method", mode="before")
|
||||
@classmethod
|
||||
def normalize_method(cls, v) -> str:
|
||||
"""Normalize HTTP method to lowercase to support both uppercase and lowercase input."""
|
||||
if isinstance(v, str):
|
||||
return v.lower()
|
||||
return v
|
||||
|
||||
@field_validator("headers", mode="after")
|
||||
@classmethod
|
||||
def validate_header_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]:
|
||||
for param in v:
|
||||
if param.type not in _WEBHOOK_HEADER_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook header parameter type: {param.type}")
|
||||
return v
|
||||
|
||||
@field_validator("params", mode="after")
|
||||
@classmethod
|
||||
def validate_query_parameter_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]:
|
||||
for param in v:
|
||||
if param.type not in _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook query parameter type: {param.type}")
|
||||
return v
|
||||
|
||||
status_code: int = 200 # Expected status code for response
|
||||
response_body: str = "" # Template for response body
|
||||
|
||||
# Webhook specific fields (not from client data, set internally)
|
||||
webhook_id: str | None = None # Set when webhook trigger is created
|
||||
timeout: int = 30 # Timeout in seconds to wait for webhook response
|
||||
@@ -1,4 +1,4 @@
|
||||
from dify_graph.nodes.base.exc import BaseNodeError
|
||||
from dify_graph.entities.exc import BaseNodeError
|
||||
|
||||
|
||||
class WebhookNodeError(BaseNodeError):
|
||||
@@ -152,7 +152,7 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
outputs[param_name] = raw_data
|
||||
continue
|
||||
|
||||
if param_type == "file":
|
||||
if param_type == SegmentType.FILE:
|
||||
# Get File object (already processed by webhook controller)
|
||||
files = webhook_data.get("files", {})
|
||||
if files and isinstance(files, dict):
|
||||
@@ -9,9 +9,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.errors import WorkflowNodeRunFailedError
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.graph import Graph
|
||||
@@ -23,7 +24,6 @@ from dify_graph.graph_engine.protocols.command_channel import CommandChannel
|
||||
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
@@ -212,7 +212,7 @@ class WorkflowEntry:
|
||||
node_config_data = node_config["data"]
|
||||
|
||||
# Get node type
|
||||
node_type = NodeType(node_config_data["type"])
|
||||
node_type = node_config_data.type
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -234,8 +234,7 @@ class WorkflowEntry:
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
typed_node_config = cast(dict[str, object], node_config)
|
||||
node = cast(Any, node_factory).create_node(typed_node_config)
|
||||
node = node_factory.create_node(node_config)
|
||||
node_cls = type(node)
|
||||
|
||||
try:
|
||||
@@ -371,10 +370,7 @@ class WorkflowEntry:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node_config: NodeConfigDict = {
|
||||
"id": node_id,
|
||||
"data": cast(NodeConfigData, node_data),
|
||||
}
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
176
api/dify_graph/entities/base_node_data.py
Normal file
176
api/dify_graph/entities/base_node_data.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from dify_graph.entities.exc import DefaultValueTypeError
|
||||
from dify_graph.enums import ErrorStrategy, NodeType
|
||||
|
||||
# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`.
|
||||
_NumberType = Union[int, float]
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
OBJECT = "object"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILES = "array[file]"
|
||||
|
||||
|
||||
class DefaultValue(BaseModel):
|
||||
value: Any = None
|
||||
type: DefaultValueType
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str):
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
|
||||
"""Unified array type validation"""
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
"""Unified number conversion handler"""
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators: dict[DefaultValueType, dict[str, Any]] = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
},
|
||||
DefaultValueType.NUMBER: {
|
||||
"type": _NumberType,
|
||||
"converter": self._convert_number,
|
||||
},
|
||||
DefaultValueType.OBJECT: {
|
||||
"type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_NUMBER: {
|
||||
"type": list,
|
||||
"element_type": _NumberType,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_STRING: {
|
||||
"type": list,
|
||||
"element_type": str,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_OBJECT: {
|
||||
"type": list,
|
||||
"element_type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
}
|
||||
|
||||
validator: dict[str, Any] = type_validators.get(self.type, {})
|
||||
if not validator:
|
||||
if self.type == DefaultValueType.ARRAY_FILES:
|
||||
# Handle files type
|
||||
return self
|
||||
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
|
||||
|
||||
# Handle string input cases
|
||||
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
|
||||
self.value = validator["converter"](self.value)
|
||||
|
||||
# Validate base type
|
||||
if not isinstance(self.value, validator["type"]):
|
||||
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
|
||||
|
||||
# Validate array element types
|
||||
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
|
||||
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
# Raw graph payloads are first validated through `NodeConfigDictAdapter`, where
|
||||
# `node["data"]` is typed as `BaseNodeData` before the concrete node class is known.
|
||||
# At that boundary, node-specific fields are still "extra" relative to this shared DTO,
|
||||
# and persisted templates/workflows also carry undeclared compatibility keys such as
|
||||
# `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive
|
||||
# here until graph parsing becomes discriminated by node type or those legacy payloads
|
||||
# are normalized.
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
type: NodeType
|
||||
title: str = ""
|
||||
desc: str | None = None
|
||||
version: str = "1"
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = Field(default_factory=RetryConfig)
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
if self.default_value:
|
||||
return {item.key: item.value for item in self.default_value}
|
||||
return {}
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""
|
||||
Dict-style access without calling model_dump() on every lookup.
|
||||
Prefer using model fields and Pydantic's extra storage.
|
||||
"""
|
||||
# First, check declared model fields
|
||||
if key in self.__class__.model_fields:
|
||||
return getattr(self, key)
|
||||
|
||||
# Then, check undeclared compatibility fields stored in Pydantic's extra dict.
|
||||
extras = getattr(self, "__pydantic_extra__", None)
|
||||
if extras is None:
|
||||
extras = getattr(self, "model_extra", None)
|
||||
if extras is not None and key in extras:
|
||||
return extras[key]
|
||||
|
||||
raise KeyError(key)
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Dict-style .get() without calling model_dump() on every lookup.
|
||||
"""
|
||||
if key in self.__class__.model_fields:
|
||||
return getattr(self, key)
|
||||
|
||||
extras = getattr(self, "__pydantic_extra__", None)
|
||||
if extras is None:
|
||||
extras = getattr(self, "model_extra", None)
|
||||
if extras is not None and key in extras:
|
||||
return extras.get(key, default)
|
||||
|
||||
return default
|
||||
@@ -4,21 +4,20 @@ import sys
|
||||
|
||||
from pydantic import TypeAdapter, with_config
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigData(TypedDict):
|
||||
type: str
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigDict(TypedDict):
|
||||
id: str
|
||||
data: NodeConfigData
|
||||
# This is the permissive raw graph boundary. Node factories re-validate `data`
|
||||
# with the concrete `NodeData` subtype after resolving the node implementation.
|
||||
data: BaseNodeData
|
||||
|
||||
|
||||
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Protocol, cast, final
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
|
||||
from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from libs.typing import is_str
|
||||
|
||||
@@ -34,7 +34,8 @@ class NodeFactory(Protocol):
|
||||
|
||||
:param node_config: node configuration dictionary containing type and other data
|
||||
:return: initialized Node instance
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
:raises ValueError: if node type is unknown or no implementation exists for the resolved version
|
||||
:raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -115,10 +116,7 @@ class Graph:
|
||||
start_node_id = None
|
||||
for nid in root_candidates:
|
||||
node_data = node_configs_map[nid]["data"]
|
||||
node_type = node_data["type"]
|
||||
if not isinstance(node_type, str):
|
||||
continue
|
||||
if NodeType(node_type).is_start_node:
|
||||
if node_data.type.is_start_node:
|
||||
start_node_id = nid
|
||||
break
|
||||
|
||||
@@ -203,6 +201,23 @@ class Graph:
|
||||
|
||||
return GraphBuilder(graph_cls=cls)
|
||||
|
||||
@staticmethod
|
||||
def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]:
|
||||
"""
|
||||
Remove editor-only nodes before `NodeConfigDict` validation.
|
||||
|
||||
Persisted note widgets use a top-level `type == "custom-note"` but leave
|
||||
`data.type` empty because they are never executable graph nodes. Filter
|
||||
them while configs are still raw dicts so Pydantic does not validate
|
||||
their placeholder payloads against `BaseNodeData.type: NodeType`.
|
||||
"""
|
||||
filtered_node_configs: list[dict[str, object]] = []
|
||||
for node_config in node_configs:
|
||||
if node_config.get("type", "") == "custom-note":
|
||||
continue
|
||||
filtered_node_configs.append(dict(node_config))
|
||||
return filtered_node_configs
|
||||
|
||||
@classmethod
|
||||
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
|
||||
"""
|
||||
@@ -302,13 +317,13 @@ class Graph:
|
||||
node_configs = graph_config.get("nodes", [])
|
||||
|
||||
edge_configs = cast(list[dict[str, object]], edge_configs)
|
||||
node_configs = cast(list[dict[str, object]], node_configs)
|
||||
node_configs = cls._filter_canvas_only_nodes(node_configs)
|
||||
node_configs = _ListNodeConfigDict.validate_python(node_configs)
|
||||
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
|
||||
|
||||
# Parse node configurations
|
||||
node_configs_map = cls._parse_node_configs(node_configs)
|
||||
|
||||
|
||||
@@ -15,8 +15,9 @@ class NodeRunStartedEvent(GraphNodeEventBase):
|
||||
predecessor_node_id: str | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
extras: dict[str, object] = Field(default_factory=dict)
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode
|
||||
# Legacy provider fields kept for existing start-event consumers.
|
||||
provider_type: str = ""
|
||||
provider_id: str = ""
|
||||
|
||||
|
||||
@@ -276,7 +276,4 @@ class ToolPromptMessage(PromptMessage):
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
if not super().is_empty() and not self.tool_call_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
return super().is_empty() and not self.tool_call_id
|
||||
|
||||
@@ -4,7 +4,8 @@ class InvokeError(ValueError):
|
||||
description: str | None = None
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
self.description = description
|
||||
if description is not None:
|
||||
self.description = description
|
||||
|
||||
def __str__(self):
|
||||
return self.description or self.__class__.__name__
|
||||
|
||||
@@ -282,7 +282,8 @@ class ModelProviderFactory:
|
||||
all_model_type_models.append(model_schema)
|
||||
|
||||
simple_provider_schema = provider_schema.to_simple_provider()
|
||||
simple_provider_schema.models.extend(all_model_type_models)
|
||||
if model_type:
|
||||
simple_provider_schema.models = all_model_type_models
|
||||
|
||||
providers.append(simple_provider_schema)
|
||||
|
||||
|
||||
@@ -374,12 +374,11 @@ class AgentNode(Node[AgentNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: AgentNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = AgentNodeData.model_validate(node_data)
|
||||
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
result: dict[str, Any] = {}
|
||||
typed_node_data = node_data
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
match input.type:
|
||||
|
||||
@@ -5,10 +5,12 @@ from pydantic import BaseModel
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class AgentNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.AGENT
|
||||
agent_strategy_provider_name: str # redundancy
|
||||
agent_strategy_name: str
|
||||
agent_strategy_label: str # redundancy
|
||||
|
||||
@@ -48,12 +48,10 @@ class AnswerNode(Node[AnswerNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: AnswerNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = AnswerNodeData.model_validate(node_data)
|
||||
|
||||
variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
variable_mapping = {}
|
||||
|
||||
@@ -3,7 +3,8 @@ from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class AnswerNodeData(BaseNodeData):
|
||||
@@ -11,6 +12,7 @@ class AnswerNodeData(BaseNodeData):
|
||||
Answer Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.ANSWER
|
||||
answer: str = Field(..., description="answer template string")
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState
|
||||
from .usage_tracking_mixin import LLMUsageTrackingMixin
|
||||
|
||||
__all__ = [
|
||||
@@ -6,6 +6,5 @@ __all__ = [
|
||||
"BaseIterationState",
|
||||
"BaseLoopNodeData",
|
||||
"BaseLoopState",
|
||||
"BaseNodeData",
|
||||
"LLMUsageTrackingMixin",
|
||||
]
|
||||
|
||||
@@ -1,31 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from dify_graph.enums import ErrorStrategy
|
||||
|
||||
from .exc import DefaultValueTypeError
|
||||
|
||||
_NumberType = Union[int, float]
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
|
||||
|
||||
class VariableSelector(BaseModel):
|
||||
@@ -76,112 +57,6 @@ class OutputVariableEntity(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
OBJECT = "object"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILES = "array[file]"
|
||||
|
||||
|
||||
class DefaultValue(BaseModel):
|
||||
value: Any = None
|
||||
type: DefaultValueType
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str):
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
|
||||
"""Unified array type validation"""
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
"""Unified number conversion handler"""
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators: dict[DefaultValueType, dict[str, Any]] = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
},
|
||||
DefaultValueType.NUMBER: {
|
||||
"type": _NumberType,
|
||||
"converter": self._convert_number,
|
||||
},
|
||||
DefaultValueType.OBJECT: {
|
||||
"type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_NUMBER: {
|
||||
"type": list,
|
||||
"element_type": _NumberType,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_STRING: {
|
||||
"type": list,
|
||||
"element_type": str,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_OBJECT: {
|
||||
"type": list,
|
||||
"element_type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
}
|
||||
|
||||
validator: dict[str, Any] = type_validators.get(self.type, {})
|
||||
if not validator:
|
||||
if self.type == DefaultValueType.ARRAY_FILES:
|
||||
# Handle files type
|
||||
return self
|
||||
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
|
||||
|
||||
# Handle string input cases
|
||||
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
|
||||
self.value = validator["converter"](self.value)
|
||||
|
||||
# Validate base type
|
||||
if not isinstance(self.value, validator["type"]):
|
||||
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
|
||||
|
||||
# Validate array element types
|
||||
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
|
||||
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: str | None = None
|
||||
version: str = "1"
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
if self.default_value:
|
||||
return {item.key: item.value for item in self.default_value}
|
||||
return {}
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, ge
|
||||
from uuid import uuid4
|
||||
|
||||
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
|
||||
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import (
|
||||
ErrorStrategy,
|
||||
@@ -62,8 +64,6 @@ from dify_graph.node_events import (
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .entities import BaseNodeData, RetryConfig
|
||||
|
||||
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
|
||||
_MISSING_RUN_CONTEXT_VALUE = object()
|
||||
|
||||
@@ -153,11 +153,11 @@ class Node(Generic[NodeDataT]):
|
||||
Later, in __init__:
|
||||
::
|
||||
|
||||
config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
|
||||
│
|
||||
▼
|
||||
CodeNodeData instance
|
||||
(stored in self._node_data)
|
||||
config["data"] ──► _node_data_type.model_validate(..., from_attributes=True)
|
||||
│
|
||||
▼
|
||||
CodeNodeData instance
|
||||
(stored in self._node_data)
|
||||
|
||||
Example:
|
||||
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
|
||||
@@ -179,7 +179,7 @@ class Node(Generic[NodeDataT]):
|
||||
# Skip base class itself
|
||||
if cls is Node:
|
||||
return
|
||||
# Only register production node implementations defined under dify_graph.nodes.*
|
||||
# Only register production node implementations defined under dify_graph.nodes.*.
|
||||
# This prevents test helper subclasses from polluting the global registry and
|
||||
# accidentally overriding real node types (e.g., a test Answer node).
|
||||
module_name = getattr(cls, "__module__", "")
|
||||
@@ -241,7 +241,7 @@ class Node(Generic[NodeDataT]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> None:
|
||||
@@ -254,26 +254,29 @@ class Node(Generic[NodeDataT]):
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.state: NodeState = NodeState.UNKNOWN # node execution state
|
||||
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
node_id = config["id"]
|
||||
|
||||
self._node_id = node_id
|
||||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
raw_node_data = config.get("data") or {}
|
||||
if not isinstance(raw_node_data, Mapping):
|
||||
raise ValueError("Node config data must be a mapping.")
|
||||
|
||||
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
|
||||
self._node_data = self.validate_node_data(config["data"])
|
||||
|
||||
self.post_init()
|
||||
|
||||
@classmethod
|
||||
def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT:
|
||||
"""Validate shared graph node payloads against the subclass-declared NodeData model."""
|
||||
return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True))
|
||||
|
||||
def post_init(self) -> None:
|
||||
"""Optional hook for subclasses requiring extra initialization."""
|
||||
return
|
||||
|
||||
def customize_start_event(self, event: NodeRunStartedEvent) -> None:
|
||||
"""Optional hook for subclasses to attach start-event metadata or extras."""
|
||||
return
|
||||
|
||||
@property
|
||||
def graph_init_params(self) -> GraphInitParams:
|
||||
return self._graph_init_params
|
||||
@@ -342,9 +345,6 @@ class Node(Generic[NodeDataT]):
|
||||
return None
|
||||
return str(execution_id)
|
||||
|
||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
@@ -383,14 +383,6 @@ class Node(Generic[NodeDataT]):
|
||||
start_event.provider_id = f"{plugin_id}/{provider_name}"
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from dify_graph.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
|
||||
|
||||
if isinstance(self, TriggerEventNode):
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from typing import cast
|
||||
|
||||
from dify_graph.nodes.agent.agent_node import AgentNode
|
||||
from dify_graph.nodes.agent.entities import AgentNodeData
|
||||
|
||||
@@ -400,6 +392,8 @@ class Node(Generic[NodeDataT]):
|
||||
icon=self.agent_strategy_icon,
|
||||
)
|
||||
|
||||
self.customize_start_event(start_event)
|
||||
|
||||
# ===
|
||||
yield start_event
|
||||
|
||||
@@ -442,7 +436,7 @@ class Node(Generic[NodeDataT]):
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""Extracts references variable selectors from node configuration.
|
||||
|
||||
@@ -480,13 +474,12 @@ class Node(Generic[NodeDataT]):
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
|
||||
# Pass raw dict data instead of creating NodeData instance
|
||||
node_id = config["id"]
|
||||
node_data = cls.validate_node_data(config["data"])
|
||||
data = cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
|
||||
graph_config=graph_config,
|
||||
node_id=node_id,
|
||||
node_data=node_data,
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -496,7 +489,7 @@ class Node(Generic[NodeDataT]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: NodeDataT,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
return {}
|
||||
|
||||
@@ -531,6 +524,7 @@ class Node(Generic[NodeDataT]):
|
||||
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
|
||||
|
||||
Import all modules under dify_graph.nodes so subclasses register themselves on import.
|
||||
Higher-level packages may register additional nodes before calling this helper.
|
||||
Then we return a readonly view of the registry to avoid accidental mutation.
|
||||
"""
|
||||
# Import all node modules to ensure they are loaded (thus registered)
|
||||
|
||||
@@ -3,6 +3,7 @@ from decimal import Decimal
|
||||
from textwrap import dedent
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
@@ -77,7 +78,7 @@ class CodeNode(Node[CodeNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -466,15 +467,12 @@ class CodeNode(Node[CodeNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: CodeNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = CodeNodeData.model_validate(node_data)
|
||||
|
||||
return {
|
||||
node_id + "." + variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in typed_node_data.variables
|
||||
for variable_selector in node_data.variables
|
||||
}
|
||||
|
||||
@property
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Annotated, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
@@ -39,6 +40,8 @@ class CodeNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.CODE
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: dict[str, "CodeNodeData.Output"] | None = None
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from dify_graph.node_events import NodeRunResult, StreamCompletedEvent
|
||||
@@ -34,7 +35,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
datasource_manager: DatasourceManagerProtocol,
|
||||
@@ -181,7 +182,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: DatasourceNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -190,11 +191,10 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
typed_node_data = DatasourceNodeData.model_validate(node_data)
|
||||
result = {}
|
||||
if typed_node_data.datasource_parameters:
|
||||
for parameter_name in typed_node_data.datasource_parameters:
|
||||
input = typed_node_data.datasource_parameters[parameter_name]
|
||||
if node_data.datasource_parameters:
|
||||
for parameter_name in node_data.datasource_parameters:
|
||||
input = node_data.datasource_parameters[parameter_name]
|
||||
match input.type:
|
||||
case "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Any, Literal, Union
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from dify_graph.nodes.base.entities import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
@@ -16,6 +17,8 @@ class DatasourceEntity(BaseModel):
|
||||
|
||||
|
||||
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
type: NodeType = NodeType.DATASOURCE
|
||||
|
||||
class DatasourceInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class DocumentExtractorNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.DOCUMENT_EXTRACTOR
|
||||
variable_selector: Sequence[str]
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from docx.oxml.text.paragraph import CT_P
|
||||
from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File, FileTransferMethod, file_manager
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
@@ -54,7 +55,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -136,12 +137,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: DocumentExtractorNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
|
||||
|
||||
return {node_id + ".files": typed_node_data.variable_selector}
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
return {node_id + ".files": node_data.variable_selector}
|
||||
|
||||
|
||||
def _extract_text_by_mime_type(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import OutputVariableEntity
|
||||
|
||||
|
||||
class EndNodeData(BaseNodeData):
|
||||
@@ -8,6 +10,7 @@ class EndNodeData(BaseNodeData):
|
||||
END Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.END
|
||||
outputs: list[OutputVariableEntity]
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,8 @@ import charset_normalizer
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config"
|
||||
|
||||
@@ -89,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.HTTP_REQUEST
|
||||
method: Literal[
|
||||
"get",
|
||||
"post",
|
||||
|
||||
@@ -3,6 +3,7 @@ import mimetypes
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File, FileTransferMethod
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
@@ -37,7 +38,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -163,18 +164,15 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: HttpRequestNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = HttpRequestNodeData.model_validate(node_data)
|
||||
|
||||
selectors: list[VariableSelector] = []
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
|
||||
if typed_node_data.body:
|
||||
body_type = typed_node_data.body.type
|
||||
data = typed_node_data.body.data
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
|
||||
if node_data.body:
|
||||
body_type = node_data.body.type
|
||||
data = node_data.body.data
|
||||
match body_type:
|
||||
case "none":
|
||||
pass
|
||||
|
||||
@@ -10,7 +10,8 @@ from typing import Annotated, Any, ClassVar, Literal, Self
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables.consts import SELECTORS_LENGTH
|
||||
@@ -71,8 +72,8 @@ class EmailDeliveryConfig(BaseModel):
|
||||
body: str
|
||||
debug_mode: bool = False
|
||||
|
||||
def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig":
|
||||
if not user_id:
|
||||
def with_debug_recipient(self, user_id: str | None) -> "EmailDeliveryConfig":
|
||||
if user_id is None:
|
||||
debug_recipients = EmailRecipients(whole_workspace=False, items=[])
|
||||
return self.model_copy(update={"recipients": debug_recipients})
|
||||
debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)])
|
||||
@@ -140,7 +141,7 @@ def apply_debug_email_recipient(
|
||||
method: DeliveryChannelConfig,
|
||||
*,
|
||||
enabled: bool,
|
||||
user_id: str,
|
||||
user_id: str | None,
|
||||
) -> DeliveryChannelConfig:
|
||||
if not enabled:
|
||||
return method
|
||||
@@ -148,7 +149,7 @@ def apply_debug_email_recipient(
|
||||
return method
|
||||
if not method.config.debug_mode:
|
||||
return method
|
||||
debug_config = method.config.with_debug_recipient(user_id or "")
|
||||
debug_config = method.config.with_debug_recipient(user_id)
|
||||
return method.model_copy(update={"config": debug_config})
|
||||
|
||||
|
||||
@@ -214,6 +215,7 @@ class UserAction(BaseModel):
|
||||
class HumanInputNodeData(BaseNodeData):
|
||||
"""Human Input node data."""
|
||||
|
||||
type: NodeType = NodeType.HUMAN_INPUT
|
||||
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
|
||||
form_content: str = ""
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import (
|
||||
@@ -63,7 +64,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
form_repository: HumanInputFormRepository,
|
||||
@@ -348,7 +349,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: HumanInputNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selectors referenced in form content and input default values.
|
||||
@@ -357,5 +358,4 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
1. Variables referenced in form_content ({{#node_name.var_name#}})
|
||||
2. Variables referenced in input default values
|
||||
"""
|
||||
validated_node_data = HumanInputNodeData.model_validate(node_data)
|
||||
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
|
||||
return node_data.extract_variable_selector_to_variable_mapping(node_id)
|
||||
|
||||
@@ -2,7 +2,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.utils.condition.entities import Condition
|
||||
|
||||
|
||||
@@ -11,6 +12,8 @@ class IfElseNodeData(BaseNodeData):
|
||||
If Else Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.IF_ELSE
|
||||
|
||||
class Case(BaseModel):
|
||||
"""
|
||||
Case entity representing a single logical condition group
|
||||
|
||||
@@ -97,13 +97,11 @@ class IfElseNode(Node[IfElseNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: IfElseNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = IfElseNodeData.model_validate(node_data)
|
||||
|
||||
var_mapping: dict[str, list[str]] = {}
|
||||
for case in typed_node_data.cases or []:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
for case in node_data.cases or []:
|
||||
for condition in case.conditions:
|
||||
key = f"{node_id}.#{'.'.join(condition.variable_selector)}#"
|
||||
var_mapping[key] = condition.variable_selector
|
||||
|
||||
@@ -3,7 +3,9 @@ from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState
|
||||
|
||||
|
||||
class ErrorHandleMode(StrEnum):
|
||||
@@ -17,6 +19,7 @@ class IterationNodeData(BaseIterationNodeData):
|
||||
Iteration Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.ITERATION
|
||||
parent_loop_id: str | None = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
@@ -31,7 +34,7 @@ class IterationStartNodeData(BaseNodeData):
|
||||
Iteration Start Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
type: NodeType = NodeType.ITERATION_START
|
||||
|
||||
|
||||
class IterationState(BaseIterationState):
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.enums import (
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
@@ -460,21 +461,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: IterationNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = IterationNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping: dict[str, Sequence[str]] = {
|
||||
f"{node_id}.input_selector": typed_node_data.iterator_selector,
|
||||
f"{node_id}.input_selector": node_data.iterator_selector,
|
||||
}
|
||||
iteration_node_ids = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
if node_data.get("iteration_id") == node_id:
|
||||
node_config_data = node.get("data", {})
|
||||
if node_config_data.get("iteration_id") == node_id:
|
||||
in_iteration_node_id = node.get("id")
|
||||
if in_iteration_node_id:
|
||||
iteration_node_ids.add(in_iteration_node_id)
|
||||
@@ -490,14 +488,15 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
# Get node class
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
|
||||
node_type = typed_sub_node_config["data"].type
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
continue
|
||||
node_version = sub_node_config.get("data", {}).get("version", "1")
|
||||
node_version = str(typed_sub_node_config["data"].version)
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
graph_config=graph_config, config=typed_sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Literal, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
@@ -155,7 +156,7 @@ class KnowledgeIndexNodeData(BaseNodeData):
|
||||
Knowledge index Node Data.
|
||||
"""
|
||||
|
||||
type: str = "knowledge-index"
|
||||
type: NodeType = NodeType.KNOWLEDGE_INDEX
|
||||
chunk_structure: str
|
||||
index_chunk_variable_selector: list[str]
|
||||
indexing_technique: str | None = None
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
@@ -30,7 +31,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
index_processor: IndexProcessorProtocol,
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
@@ -113,7 +114,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
Knowledge retrieval Node Data.
|
||||
"""
|
||||
|
||||
type: str = "knowledge-retrieval"
|
||||
type: NodeType = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
query_variable_selector: list[str] | None | str = None
|
||||
query_attachment_selector: list[str] | None | str = None
|
||||
dataset_ids: list[str]
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
@@ -49,7 +50,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
rag_retrieval: RAGRetrievalProtocol,
|
||||
@@ -301,15 +302,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: KnowledgeRetrievalNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
if typed_node_data.query_variable_selector:
|
||||
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
|
||||
if typed_node_data.query_attachment_selector:
|
||||
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
|
||||
if node_data.query_variable_selector:
|
||||
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
|
||||
if node_data.query_attachment_selector:
|
||||
variable_mapping[node_id + ".queryAttachment"] = node_data.query_attachment_selector
|
||||
return variable_mapping
|
||||
|
||||
@@ -3,7 +3,8 @@ from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
|
||||
|
||||
class FilterOperator(StrEnum):
|
||||
@@ -62,6 +63,7 @@ class ExtractConfig(BaseModel):
|
||||
|
||||
|
||||
class ListOperatorNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.LIST_OPERATOR
|
||||
variable: Sequence[str] = Field(default_factory=list)
|
||||
filter_by: FilterBy
|
||||
order_by: OrderByConfig
|
||||
|
||||
@@ -4,8 +4,9 @@ from typing import Any, Literal
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
@@ -59,6 +60,7 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.LLM
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
|
||||
|
||||
@@ -21,6 +21,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.tools.signature import sign_upload_file
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
@@ -121,7 +122,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
@@ -954,14 +955,11 @@ class LLMNode(Node[LLMNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: LLMNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = LLMNodeData.model_validate(node_data)
|
||||
|
||||
prompt_template = typed_node_data.prompt_template
|
||||
prompt_template = node_data.prompt_template
|
||||
variable_selectors = []
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
@@ -979,7 +977,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
memory = typed_node_data.memory
|
||||
memory = node_data.memory
|
||||
if memory and memory.query_prompt_template:
|
||||
query_variable_selectors = VariableTemplateParser(
|
||||
template=memory.query_prompt_template
|
||||
@@ -987,16 +985,16 @@ class LLMNode(Node[LLMNodeData]):
|
||||
for variable_selector in query_variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
if typed_node_data.context.enabled:
|
||||
variable_mapping["#context#"] = typed_node_data.context.variable_selector
|
||||
if node_data.context.enabled:
|
||||
variable_mapping["#context#"] = node_data.context.variable_selector
|
||||
|
||||
if typed_node_data.vision.enabled:
|
||||
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
|
||||
if node_data.vision.enabled:
|
||||
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
|
||||
|
||||
if typed_node_data.memory:
|
||||
if node_data.memory:
|
||||
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
|
||||
|
||||
if typed_node_data.prompt_config:
|
||||
if node_data.prompt_config:
|
||||
enable_jinja = False
|
||||
|
||||
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||
@@ -1009,7 +1007,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
break
|
||||
|
||||
if enable_jinja:
|
||||
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
|
||||
|
||||
@@ -3,7 +3,9 @@ from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||
|
||||
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
|
||||
from dify_graph.utils.condition.entities import Condition
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
@@ -39,6 +41,7 @@ class LoopVariableData(BaseModel):
|
||||
|
||||
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
type: NodeType = NodeType.LOOP
|
||||
loop_count: int # Maximum number of loops
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
@@ -58,7 +61,7 @@ class LoopStartNodeData(BaseNodeData):
|
||||
Loop Start Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
type: NodeType = NodeType.LOOP_START
|
||||
|
||||
|
||||
class LoopEndNodeData(BaseNodeData):
|
||||
@@ -66,7 +69,7 @@ class LoopEndNodeData(BaseNodeData):
|
||||
Loop End Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
type: NodeType = NodeType.LOOP_END
|
||||
|
||||
|
||||
class LoopState(BaseLoopState):
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.enums import (
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
@@ -298,11 +299,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: LoopNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = LoopNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
|
||||
# Extract loop node IDs statically from graph_config
|
||||
@@ -320,14 +318,15 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
# Get node class
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
|
||||
node_type = typed_sub_node_config["data"].type
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
continue
|
||||
node_version = sub_node_config.get("data", {}).get("version", "1")
|
||||
node_version = str(typed_sub_node_config["data"].version)
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
graph_config=graph_config, config=typed_sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
@@ -342,7 +341,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
for loop_variable in typed_node_data.loop_variables or []:
|
||||
for loop_variable in node_data.loop_variables or []:
|
||||
if loop_variable.value_type == "variable":
|
||||
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
|
||||
# add loop variable to variable mapping
|
||||
|
||||
@@ -8,7 +8,8 @@ from pydantic import (
|
||||
)
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
@@ -83,6 +84,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
Parameter Extractor Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.PARAMETER_EXTRACTOR
|
||||
model: ModelConfig
|
||||
query: list[str]
|
||||
parameters: list[ParameterConfig]
|
||||
|
||||
@@ -10,6 +10,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
@@ -106,7 +107,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -837,15 +838,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: ParameterExtractorNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
|
||||
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
|
||||
|
||||
if typed_node_data.instruction:
|
||||
selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
|
||||
if node_data.instruction:
|
||||
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
|
||||
for selector in selectors:
|
||||
variable_mapping[selector.variable] = selector.value_selector
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.llm import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
@@ -11,6 +12,7 @@ class ClassConfig(BaseModel):
|
||||
|
||||
|
||||
class QuestionClassifierNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.QUESTION_CLASSIFIER
|
||||
query_variable_selector: list[str]
|
||||
model: ModelConfig
|
||||
classes: list[ClassConfig]
|
||||
|
||||
@@ -7,6 +7,7 @@ from core.model_manager import ModelInstance
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
@@ -62,7 +63,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -251,16 +252,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: QuestionClassifierNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {"query": typed_node_data.query_variable_selector}
|
||||
variable_mapping = {"query": node_data.query_variable_selector}
|
||||
variable_selectors: list[VariableSelector] = []
|
||||
if typed_node_data.instruction:
|
||||
variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
|
||||
if node_data.instruction:
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)
|
||||
|
||||
@@ -2,7 +2,8 @@ from collections.abc import Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.variables.input_entities import VariableEntity
|
||||
|
||||
|
||||
@@ -11,4 +12,5 @@ class StartNodeData(BaseNodeData):
|
||||
Start Node Data
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.START
|
||||
variables: Sequence[VariableEntity] = Field(default_factory=list)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from dify_graph.nodes.base import BaseNodeData
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
@@ -7,5 +8,6 @@ class TemplateTransformNodeData(BaseNodeData):
|
||||
Template Transform Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.TEMPLATE_TRANSFORM
|
||||
variables: list[VariableSelector]
|
||||
template: str
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user