Compare commits


31 Commits

Author SHA1 Message Date
Stephen Zhou
f8a3d1cddc clean 2026-03-27 19:18:08 +08:00
Stephen Zhou
e93cb210f8 tweaks 2026-03-27 19:10:50 +08:00
Stephen Zhou
21a8dedb5e fix 2026-03-27 19:04:03 +08:00
Stephen Zhou
c5bb95ce00 update 2026-03-27 18:48:30 +08:00
Stephen Zhou
ecbc8ed3e6 update 2026-03-27 18:44:11 +08:00
Stephen Zhou
7e1d15386f update 2026-03-27 18:41:43 +08:00
Stephen Zhou
dedc6e7e2a tweaks 2026-03-27 18:23:54 +08:00
Stephen Zhou
36083c5316 tweaks 2026-03-27 18:19:46 +08:00
Stephen Zhou
78e1c69f64 add E2E_SLOW_MO 2026-03-27 18:17:59 +08:00
Stephen Zhou
e0256147e4 clean 2026-03-27 18:12:01 +08:00
Stephen Zhou
8b0a1dbe82 clean 2026-03-27 18:09:49 +08:00
Stephen Zhou
d92f60d942 try cucumber 2026-03-27 18:06:51 +08:00
Stephen Zhou
4cc9e73d4a test: init e2e 2026-03-27 17:13:12 +08:00
Xiyuan Chen
5a8a68cab8 feat: enterprise otel exporter (#33138)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-27 07:56:31 +00:00
wangxiaolei
689761bfcb feat: return correct dify-plugin-daemon error message (#34171) 2026-03-27 06:02:29 +00:00
Stephen Zhou
2394e45ec7 ci: skip duplicate actions (#34168) 2026-03-27 02:44:57 +00:00
1Ckpwee
01e6a3a9d9 chore(ci): remove Python 3.11 from CI test workflows (#34164) 2026-03-27 02:41:19 +00:00
Stephen Zhou
07f4950cb3 test: use happy dom (#34154)
2026-03-27 01:46:19 +00:00
非法操作
368896d84d feat: add copy/delete to multi nodes context menu (#34138) 2026-03-27 01:20:39 +00:00
YBoy
408f650b0c test: migrate auth integration tests to testcontainers (#34089)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-03-26 23:25:36 +00:00
dependabot[bot]
7c2e1fa3e2 chore(deps): bump brace-expansion from 5.0.4 to 5.0.5 in /sdks/nodejs-client (#34159)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-26 23:21:18 +00:00
YBoy
1da66b9a8c test: migrate api token service tests to testcontainers (#34148)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-26 21:02:09 +00:00
dependabot[bot]
4953762f4e chore(deps): bump requests from 2.32.5 to 2.33.0 in /api (#34116)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-26 20:59:35 +00:00
YBoy
97764c4a57 test: migrate plugin service tests to testcontainers (#34098)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-03-26 20:36:12 +00:00
tmimmanuel
2ea85d3ba2 refactor: use EnumText for model_type and WorkflowNodeExecution.status (#34093)
Co-authored-by: Krishna Chaitanya <krishnabkc15@gmail.com>
2026-03-26 20:34:44 +00:00
dependabot[bot]
1f11300175 chore(deps-dev): bump nltk from 3.9.3 to 3.9.4 in /api (#34117)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-26 20:31:40 +00:00
YBoy
f317db525f test: migrate api key auth service tests to testcontainers (#34147)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-26 20:31:18 +00:00
YBoy
3fa0538f72 test: migrate human input delivery test service tests to testcontainers (#34092)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-26 20:29:20 +00:00
99
fcfc96ca05 chore: remove stale mypy suppressions and align dataset service tests (#34130)
2026-03-26 12:34:44 +00:00
-LAN-
69c2b422de chore: Keep main CI lane checks stable when skipped (#34143)
2026-03-26 09:29:41 +00:00
-LAN-
496baa9335 chore(api): remove backend utcnow usage (#34131) 2026-03-26 08:51:49 +00:00
203 changed files with 16604 additions and 11112 deletions

View File

@@ -25,7 +25,6 @@ jobs:
strategy:
matrix:
python-version:
- "3.11"
- "3.12"
steps:

View File

@@ -10,6 +10,7 @@ on:
branches: ["main"]
permissions:
actions: write
contents: write
pull-requests: write
checks: write
@@ -20,12 +21,28 @@ concurrency:
cancel-in-progress: true
jobs:
pre_job:
name: Skip Duplicate Checks
runs-on: ubuntu-latest
outputs:
should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }}
steps:
- id: skip_check
continue-on-error: true
uses: fkirc/skip-duplicate-actions@f75f66ce1886f00957d99748a42c724f4330bdcf # v5.3.1
with:
cancel_others: 'true'
concurrent_skipping: same_content_newer
# Check which paths were changed to determine which tests to run
check-changes:
name: Check Changed Files
needs: pre_job
if: needs.pre_job.outputs.should_skip != 'true'
runs-on: ubuntu-latest
outputs:
api-changed: ${{ steps.changes.outputs.api }}
e2e-changed: ${{ steps.changes.outputs.e2e }}
web-changed: ${{ steps.changes.outputs.web }}
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
@@ -43,6 +60,16 @@ jobs:
- 'web/**'
- '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**'
e2e:
- 'api/**'
- 'api/pyproject.toml'
- 'api/uv.lock'
- 'e2e/**'
- 'web/**'
- 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example'
- '.github/workflows/web-e2e.yml'
- '.github/actions/setup-web/**'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'
@@ -53,33 +80,306 @@ jobs:
- 'api/migrations/**'
- '.github/workflows/db-migration-test.yml'
# Run tests in parallel
api-tests:
name: API Tests
needs: check-changes
if: needs.check-changes.outputs.api-changed == 'true'
# Run tests in parallel while always emitting stable required checks.
api-tests-run:
name: Run API Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed == 'true'
uses: ./.github/workflows/api-tests.yml
secrets: inherit
web-tests:
name: Web Tests
needs: check-changes
if: needs.check-changes.outputs.web-changed == 'true'
api-tests-skip:
name: Skip API Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true'
runs-on: ubuntu-latest
steps:
- name: Report skipped API tests
run: echo "No API-related changes detected; skipping API tests."
api-tests:
name: API Tests
if: ${{ always() }}
needs:
- pre_job
- check-changes
- api-tests-run
- api-tests-skip
runs-on: ubuntu-latest
steps:
- name: Finalize API Tests status
env:
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
TESTS_CHANGED: ${{ needs.check-changes.outputs.api-changed }}
RUN_RESULT: ${{ needs.api-tests-run.result }}
SKIP_RESULT: ${{ needs.api-tests-skip.result }}
run: |
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
echo "API tests were skipped because this workflow run duplicated a successful or newer run."
exit 0
fi
if [[ "$TESTS_CHANGED" == 'true' ]]; then
if [[ "$RUN_RESULT" == 'success' ]]; then
echo "API tests ran successfully."
exit 0
fi
echo "API tests were required but finished with result: $RUN_RESULT" >&2
exit 1
fi
if [[ "$SKIP_RESULT" == 'success' ]]; then
echo "API tests were skipped because no API-related files changed."
exit 0
fi
echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1
web-tests-run:
name: Run Web Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed == 'true'
uses: ./.github/workflows/web-tests.yml
secrets: inherit
web-tests-skip:
name: Skip Web Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true'
runs-on: ubuntu-latest
steps:
- name: Report skipped web tests
run: echo "No web-related changes detected; skipping web tests."
web-tests:
name: Web Tests
if: ${{ always() }}
needs:
- pre_job
- check-changes
- web-tests-run
- web-tests-skip
runs-on: ubuntu-latest
steps:
- name: Finalize Web Tests status
env:
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
TESTS_CHANGED: ${{ needs.check-changes.outputs.web-changed }}
RUN_RESULT: ${{ needs.web-tests-run.result }}
SKIP_RESULT: ${{ needs.web-tests-skip.result }}
run: |
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
echo "Web tests were skipped because this workflow run duplicated a successful or newer run."
exit 0
fi
if [[ "$TESTS_CHANGED" == 'true' ]]; then
if [[ "$RUN_RESULT" == 'success' ]]; then
echo "Web tests ran successfully."
exit 0
fi
echo "Web tests were required but finished with result: $RUN_RESULT" >&2
exit 1
fi
if [[ "$SKIP_RESULT" == 'success' ]]; then
echo "Web tests were skipped because no web-related files changed."
exit 0
fi
echo "Web tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1
web-e2e-run:
name: Run Web Full-Stack E2E
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed == 'true'
uses: ./.github/workflows/web-e2e.yml
web-e2e-skip:
name: Skip Web Full-Stack E2E
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true'
runs-on: ubuntu-latest
steps:
- name: Report skipped web full-stack e2e
run: echo "No E2E-related changes detected; skipping web full-stack E2E."
web-e2e:
name: Web Full-Stack E2E
if: ${{ always() }}
needs:
- pre_job
- check-changes
- web-e2e-run
- web-e2e-skip
runs-on: ubuntu-latest
steps:
- name: Finalize Web Full-Stack E2E status
env:
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
TESTS_CHANGED: ${{ needs.check-changes.outputs.e2e-changed }}
RUN_RESULT: ${{ needs.web-e2e-run.result }}
SKIP_RESULT: ${{ needs.web-e2e-skip.result }}
run: |
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
echo "Web full-stack E2E was skipped because this workflow run duplicated a successful or newer run."
exit 0
fi
if [[ "$TESTS_CHANGED" == 'true' ]]; then
if [[ "$RUN_RESULT" == 'success' ]]; then
echo "Web full-stack E2E ran successfully."
exit 0
fi
echo "Web full-stack E2E was required but finished with result: $RUN_RESULT" >&2
exit 1
fi
if [[ "$SKIP_RESULT" == 'success' ]]; then
echo "Web full-stack E2E was skipped because no E2E-related files changed."
exit 0
fi
echo "Web full-stack E2E was not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1
style-check:
name: Style Check
needs: pre_job
if: needs.pre_job.outputs.should_skip != 'true'
uses: ./.github/workflows/style.yml
vdb-tests-run:
name: Run VDB Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed == 'true'
uses: ./.github/workflows/vdb-tests.yml
vdb-tests-skip:
name: Skip VDB Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true'
runs-on: ubuntu-latest
steps:
- name: Report skipped VDB tests
run: echo "No VDB-related changes detected; skipping VDB tests."
vdb-tests:
name: VDB Tests
needs: check-changes
if: needs.check-changes.outputs.vdb-changed == 'true'
uses: ./.github/workflows/vdb-tests.yml
if: ${{ always() }}
needs:
- pre_job
- check-changes
- vdb-tests-run
- vdb-tests-skip
runs-on: ubuntu-latest
steps:
- name: Finalize VDB Tests status
env:
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
TESTS_CHANGED: ${{ needs.check-changes.outputs.vdb-changed }}
RUN_RESULT: ${{ needs.vdb-tests-run.result }}
SKIP_RESULT: ${{ needs.vdb-tests-skip.result }}
run: |
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
echo "VDB tests were skipped because this workflow run duplicated a successful or newer run."
exit 0
fi
if [[ "$TESTS_CHANGED" == 'true' ]]; then
if [[ "$RUN_RESULT" == 'success' ]]; then
echo "VDB tests ran successfully."
exit 0
fi
echo "VDB tests were required but finished with result: $RUN_RESULT" >&2
exit 1
fi
if [[ "$SKIP_RESULT" == 'success' ]]; then
echo "VDB tests were skipped because no VDB-related files changed."
exit 0
fi
echo "VDB tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1
db-migration-test-run:
name: Run DB Migration Test
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed == 'true'
uses: ./.github/workflows/db-migration-test.yml
db-migration-test-skip:
name: Skip DB Migration Test
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true'
runs-on: ubuntu-latest
steps:
- name: Report skipped DB migration tests
run: echo "No migration-related changes detected; skipping DB migration tests."
db-migration-test:
name: DB Migration Test
needs: check-changes
if: needs.check-changes.outputs.migration-changed == 'true'
uses: ./.github/workflows/db-migration-test.yml
if: ${{ always() }}
needs:
- pre_job
- check-changes
- db-migration-test-run
- db-migration-test-skip
runs-on: ubuntu-latest
steps:
- name: Finalize DB Migration Test status
env:
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
TESTS_CHANGED: ${{ needs.check-changes.outputs.migration-changed }}
RUN_RESULT: ${{ needs.db-migration-test-run.result }}
SKIP_RESULT: ${{ needs.db-migration-test-skip.result }}
run: |
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
echo "DB migration tests were skipped because this workflow run duplicated a successful or newer run."
exit 0
fi
if [[ "$TESTS_CHANGED" == 'true' ]]; then
if [[ "$RUN_RESULT" == 'success' ]]; then
echo "DB migration tests ran successfully."
exit 0
fi
echo "DB migration tests were required but finished with result: $RUN_RESULT" >&2
exit 1
fi
if [[ "$SKIP_RESULT" == 'success' ]]; then
echo "DB migration tests were skipped because no migration-related files changed."
exit 0
fi
echo "DB migration tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1

View File

@@ -14,7 +14,6 @@ jobs:
strategy:
matrix:
python-version:
- "3.11"
- "3.12"
steps:

.github/workflows/web-e2e.yml (new file, 81 lines)
View File

@@ -0,0 +1,81 @@
name: Web Full-Stack E2E
on:
workflow_call:
permissions:
contents: read
concurrency:
group: web-e2e-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: Web Full-Stack E2E
runs-on: ubuntu-latest
defaults:
run:
shell: bash
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup web dependencies
uses: ./.github/actions/setup-web
- name: Install E2E package dependencies
working-directory: ./e2e
run: vp install --frozen-lockfile
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: "3.12"
cache-dependency-glob: api/uv.lock
- name: Install API dependencies
run: uv sync --project api --dev
- name: Start middleware stack
working-directory: ./e2e
run: vp run e2e:middleware:up
- name: Install Playwright browser
working-directory: ./e2e
run: vp run e2e:install
- name: Run source-api and built-web Cucumber E2E tests
working-directory: ./e2e
env:
E2E_ADMIN_EMAIL: e2e-admin@example.com
E2E_ADMIN_NAME: E2E Admin
E2E_ADMIN_PASSWORD: E2eAdmin12345
E2E_FORCE_WEB_BUILD: "1"
E2E_INIT_PASSWORD: E2eInit12345
run: vp run e2e
- name: Upload Cucumber report
if: ${{ !cancelled() }}
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: cucumber-report
path: e2e/cucumber-report
retention-days: 7
- name: Upload E2E logs
if: ${{ !cancelled() }}
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: e2e-logs
path: e2e/.logs
retention-days: 7
- name: Stop middleware stack
if: ${{ always() }}
working-directory: ./e2e
run: vp run e2e:middleware:down

View File

@@ -143,6 +143,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
ext_enterprise_telemetry,
ext_fastopenapi,
ext_forward_refs,
ext_hosting_provider,
@@ -193,6 +194,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_fastopenapi,
ext_otel,
ext_enterprise_telemetry,
ext_request_logging,
ext_session_factory,
]
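The ext_enterprise_telemetry module registered in the two initialization lists above is not itself part of the hunks in this compare view. A minimal, hypothetical sketch of its shape, assuming the is_enabled()/init_app() hooks the other ext_* modules follow and the ENTERPRISE_* settings added later in this diff:

# Hypothetical sketch only; the real ext_enterprise_telemetry module is not shown here.
from configs import dify_config


def is_enabled() -> bool:
    # Opt-in, and per the setting's description it also requires ENTERPRISE_ENABLED=true.
    return dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED


def init_app(app) -> None:  # app: DifyApp, matching initialize_extensions above
    # Install the tracer provider; the exporter wiring from the ENTERPRISE_* settings
    # is sketched after EnterpriseTelemetryConfig further down in this diff.
    from opentelemetry import trace

    trace.set_tracer_provider(build_enterprise_tracer_provider())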

View File

@@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings
from libs.file_utils import search_file_upwards
from .deploy import DeploymentConfig
from .enterprise import EnterpriseFeatureConfig
from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig
from .extra import ExtraServiceConfig
from .feature import FeatureConfig
from .middleware import MiddlewareConfig
@@ -73,6 +73,8 @@ class DifyConfig(
# Enterprise feature configs
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
# Enterprise telemetry configs
EnterpriseTelemetryConfig,
):
model_config = SettingsConfigDict(
# read from dotenv format config file

View File

@@ -22,3 +22,52 @@ class EnterpriseFeatureConfig(BaseSettings):
ENTERPRISE_REQUEST_TIMEOUT: int = Field(
ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
)
class EnterpriseTelemetryConfig(BaseSettings):
"""
Configuration for enterprise telemetry.
"""
ENTERPRISE_TELEMETRY_ENABLED: bool = Field(
description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).",
default=False,
)
ENTERPRISE_OTLP_ENDPOINT: str = Field(
description="Enterprise OTEL collector endpoint.",
default="",
)
ENTERPRISE_OTLP_HEADERS: str = Field(
description="Auth headers for OTLP export (key=value,key2=value2).",
default="",
)
ENTERPRISE_OTLP_PROTOCOL: str = Field(
description="OTLP protocol: 'http' or 'grpc' (default: http).",
default="http",
)
ENTERPRISE_OTLP_API_KEY: str = Field(
description="Bearer token for enterprise OTLP export authentication.",
default="",
)
ENTERPRISE_INCLUDE_CONTENT: bool = Field(
description="Include input/output content in traces (privacy toggle).",
# Default to False to avoid accidentally logging PII data in traces.
default=False,
)
ENTERPRISE_SERVICE_NAME: str = Field(
description="Service name for OTEL resource.",
default="dify",
)
ENTERPRISE_OTEL_SAMPLING_RATE: float = Field(
description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).",
default=1.0,
ge=0.0,
le=1.0,
)
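The field descriptions above define the expected formats: comma-separated key=value pairs for ENTERPRISE_OTLP_HEADERS, an 'http' or 'grpc' ENTERPRISE_OTLP_PROTOCOL, and a 0.0 to 1.0 sampling rate. The exporter wiring that consumes these settings is not included in this compare view; the following is a minimal sketch, assuming the standard OpenTelemetry SDK and covering only the default http protocol. build_enterprise_tracer_provider and _parse_otlp_headers are hypothetical names.

# Hypothetical sketch only; not the repository's actual exporter code.
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.trace.sampling import ParentBased, TraceIdRatioBased

from configs import dify_config


def _parse_otlp_headers(raw: str) -> dict[str, str]:
    # "key=value,key2=value2" -> {"key": "value", "key2": "value2"}
    headers: dict[str, str] = {}
    for pair in raw.split(","):
        key, sep, value = pair.partition("=")
        if sep:
            headers[key.strip()] = value.strip()
    return headers


def build_enterprise_tracer_provider() -> TracerProvider:
    headers = _parse_otlp_headers(dify_config.ENTERPRISE_OTLP_HEADERS)
    if dify_config.ENTERPRISE_OTLP_API_KEY:
        # ENTERPRISE_OTLP_API_KEY is described above as a bearer token for export authentication.
        headers.setdefault("Authorization", f"Bearer {dify_config.ENTERPRISE_OTLP_API_KEY}")

    provider = TracerProvider(
        resource=Resource.create({"service.name": dify_config.ENTERPRISE_SERVICE_NAME}),
        sampler=ParentBased(TraceIdRatioBased(dify_config.ENTERPRISE_OTEL_SAMPLING_RATE)),
    )
    provider.add_span_processor(
        BatchSpanProcessor(
            OTLPSpanExporter(
                endpoint=dify_config.ENTERPRISE_OTLP_ENDPOINT,
                headers=headers,
            )
        )
    )
    return provider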

View File

@@ -1366,32 +1366,6 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
)
class EvaluationConfig(BaseSettings):
"""
Configuration for evaluation runtime
"""
EVALUATION_FRAMEWORK: str = Field(
description="Evaluation framework to use (ragas/deepeval/none)",
default="none",
)
EVALUATION_MAX_CONCURRENT_RUNS: PositiveInt = Field(
description="Maximum number of concurrent evaluation runs per tenant",
default=3,
)
EVALUATION_MAX_DATASET_ROWS: PositiveInt = Field(
description="Maximum number of rows allowed in an evaluation dataset",
default=500,
)
EVALUATION_TASK_TIMEOUT: PositiveInt = Field(
description="Timeout in seconds for a single evaluation task",
default=3600,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@@ -1404,7 +1378,6 @@ class FeatureConfig(
MarketplaceConfig,
DataSetConfig,
EndpointConfig,
EvaluationConfig,
FileAccessConfig,
FileUploadConfig,
HttpConfig,

View File

@@ -107,9 +107,6 @@ from .datasets.rag_pipeline import (
rag_pipeline_workflow,
)
# Import evaluation controllers
from .evaluation import evaluation
# Import explore controllers
from .explore import (
banner,
@@ -120,9 +117,6 @@ from .explore import (
trial,
)
# Import snippet controllers
from .snippets import snippet_workflow
# Import tag controllers
from .tag import tags
@@ -136,7 +130,6 @@ from .workspace import (
model_providers,
models,
plugin,
snippets,
tool_providers,
trigger_providers,
workspace,
@@ -174,7 +167,6 @@ __all__ = [
"datasource_content_preview",
"email_register",
"endpoint",
"evaluation",
"extension",
"external",
"feature",
@@ -209,8 +201,6 @@ __all__ = [
"saved_message",
"setup",
"site",
"snippet_workflow",
"snippets",
"spec",
"statistic",
"tags",

View File

@@ -1,5 +1,5 @@
from datetime import UTC, datetime, timedelta
from typing import Literal, cast
from typing import Literal, TypedDict, cast
from flask import request
from flask_restx import Resource, fields, marshal_with
@@ -173,6 +173,23 @@ console_ns.schema_model(
)
class HumanInputPauseTypeResponse(TypedDict):
type: Literal["human_input"]
form_id: str
backstage_input_url: str | None
class PausedNodeResponse(TypedDict):
node_id: str
node_title: str
pause_type: HumanInputPauseTypeResponse
class WorkflowPauseDetailsResponse(TypedDict):
paused_at: str | None
paused_nodes: list[PausedNodeResponse]
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
class AdvancedChatAppWorkflowRunListApi(Resource):
@console_ns.doc("get_advanced_chat_workflow_runs")
@@ -490,10 +507,11 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
# Check if workflow is suspended
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
if not is_paused:
return {
empty_response: WorkflowPauseDetailsResponse = {
"paused_at": None,
"paused_nodes": [],
}, 200
}
return empty_response, 200
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
@@ -503,8 +521,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
# Build response
paused_at = pause_entity.paused_at if pause_entity else None
paused_nodes = []
response = {
paused_nodes: list[PausedNodeResponse] = []
response: WorkflowPauseDetailsResponse = {
"paused_at": paused_at.isoformat() + "Z" if paused_at else None,
"paused_nodes": paused_nodes,
}

View File

@@ -1,13 +1,10 @@
import json
from typing import Any, cast
from urllib.parse import quote
from flask import Response, request
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
@@ -27,7 +24,6 @@ from controllers.console.wraps import (
setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
from core.indexing_runner import IndexingRunner
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.datasource.vdb.vector_type import VectorType
@@ -36,7 +32,6 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
content_fields,
@@ -58,19 +53,12 @@ from fields.dataset_fields import (
from fields.document_fields import document_status_fields
from graphon.model_runtime.entities.model_entities import ModelType
from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import ApiTokenType, SegmentStatus
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.errors.evaluation import (
EvaluationDatasetInvalidError,
EvaluationFrameworkNotConfiguredError,
EvaluationMaxConcurrentRunsError,
EvaluationNotFoundError,
)
from services.evaluation_service import EvaluationService
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
@@ -998,429 +986,3 @@ class DatasetAutoDisableLogApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
# ---- Knowledge Base Retrieval Evaluation ----
def _serialize_dataset_evaluation_run(run: EvaluationRun) -> dict[str, Any]:
return {
"id": run.id,
"tenant_id": run.tenant_id,
"target_type": run.target_type,
"target_id": run.target_id,
"evaluation_config_id": run.evaluation_config_id,
"status": run.status,
"dataset_file_id": run.dataset_file_id,
"result_file_id": run.result_file_id,
"total_items": run.total_items,
"completed_items": run.completed_items,
"failed_items": run.failed_items,
"progress": run.progress,
"metrics_summary": json.loads(run.metrics_summary) if run.metrics_summary else {},
"error": run.error,
"created_by": run.created_by,
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
}
def _serialize_dataset_evaluation_run_item(item: Any) -> dict[str, Any]:
return {
"id": item.id,
"item_index": item.item_index,
"inputs": item.inputs_dict,
"expected_output": item.expected_output,
"actual_output": item.actual_output,
"metrics": item.metrics_list,
"judgment": item.judgment_dict,
"metadata": item.metadata_dict,
"error": item.error,
"overall_score": item.overall_score,
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/template/download")
class DatasetEvaluationTemplateDownloadApi(Resource):
@console_ns.doc("download_dataset_evaluation_template")
@console_ns.response(200, "Template file streamed as XLSX attachment")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
"""Download evaluation dataset template for knowledge base retrieval."""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
xlsx_content, filename = EvaluationService.generate_retrieval_dataset_template()
encoded_filename = quote(filename)
response = Response(
xlsx_content,
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Length"] = str(len(xlsx_content))
return response
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation")
class DatasetEvaluationDetailApi(Resource):
@console_ns.doc("get_dataset_evaluation_config")
@console_ns.response(200, "Evaluation configuration retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
"""Get evaluation configuration for the knowledge base."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.get_evaluation_config(
session, current_tenant_id, "dataset", dataset_id_str
)
if config is None:
return {
"evaluation_model": None,
"evaluation_model_provider": None,
"metrics_config": None,
"judgement_conditions": None,
}
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"metrics_config": config.metrics_config_dict,
"judgement_conditions": config.judgement_conditions_dict,
}
@console_ns.doc("save_dataset_evaluation_config")
@console_ns.response(200, "Evaluation configuration saved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def put(self, dataset_id):
"""Save evaluation configuration for the knowledge base."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
body = request.get_json(force=True)
try:
config_data = EvaluationConfigData.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.save_evaluation_config(
session=session,
tenant_id=current_tenant_id,
target_type="dataset",
target_id=dataset_id_str,
account_id=str(current_user.id),
data=config_data,
)
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"metrics_config": config.metrics_config_dict,
"judgement_conditions": config.judgement_conditions_dict,
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/run")
class DatasetEvaluationRunApi(Resource):
@console_ns.doc("start_dataset_evaluation_run")
@console_ns.response(200, "Evaluation run started")
@console_ns.response(400, "Invalid request")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
"""Start an evaluation run for the knowledge base retrieval."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
body = request.get_json(force=True)
if not body:
raise BadRequest("Request body is required.")
try:
run_request = EvaluationRunRequest.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
upload_file = (
db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first()
)
if not upload_file:
raise NotFound("Dataset file not found.")
try:
dataset_content = storage.load_once(upload_file.key)
except Exception:
raise BadRequest("Failed to read dataset file.")
if not dataset_content:
raise BadRequest("Dataset file is empty.")
try:
with Session(db.engine, expire_on_commit=False) as session:
evaluation_run = EvaluationService.start_evaluation_run(
session=session,
tenant_id=current_tenant_id,
target_type=EvaluationTargetType.KNOWLEDGE_BASE,
target_id=dataset_id_str,
account_id=str(current_user.id),
dataset_file_content=dataset_content,
run_request=run_request,
)
return _serialize_dataset_evaluation_run(evaluation_run), 200
except EvaluationFrameworkNotConfiguredError as e:
return {"message": str(e.description)}, 400
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except EvaluationMaxConcurrentRunsError as e:
return {"message": str(e.description)}, 429
except EvaluationDatasetInvalidError as e:
return {"message": str(e.description)}, 400
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/logs")
class DatasetEvaluationLogsApi(Resource):
@console_ns.doc("get_dataset_evaluation_logs")
@console_ns.response(200, "Evaluation logs retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
"""Get evaluation run history for the knowledge base."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
with Session(db.engine, expire_on_commit=False) as session:
runs, total = EvaluationService.get_evaluation_runs(
session=session,
tenant_id=current_tenant_id,
target_type="dataset",
target_id=dataset_id_str,
page=page,
page_size=page_size,
)
return {
"data": [_serialize_dataset_evaluation_run(run) for run in runs],
"total": total,
"page": page,
"page_size": page_size,
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>")
class DatasetEvaluationRunDetailApi(Resource):
@console_ns.doc("get_dataset_evaluation_run_detail")
@console_ns.response(200, "Evaluation run detail retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset or run not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, run_id):
"""Get evaluation run detail including per-item results."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
run_id_str = str(run_id)
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 50, type=int)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.get_evaluation_run_detail(
session=session,
tenant_id=current_tenant_id,
run_id=run_id_str,
)
items, total_items = EvaluationService.get_evaluation_run_items(
session=session,
run_id=run_id_str,
page=page,
page_size=page_size,
)
return {
"run": _serialize_dataset_evaluation_run(run),
"items": {
"data": [_serialize_dataset_evaluation_run_item(item) for item in items],
"total": total_items,
"page": page,
"page_size": page_size,
},
}
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>/cancel")
class DatasetEvaluationRunCancelApi(Resource):
@console_ns.doc("cancel_dataset_evaluation_run")
@console_ns.response(200, "Evaluation run cancelled")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset or run not found")
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id, run_id):
"""Cancel a running knowledge base evaluation."""
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
run_id_str = str(run_id)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.cancel_evaluation_run(
session=session,
tenant_id=current_tenant_id,
run_id=run_id_str,
)
return _serialize_dataset_evaluation_run(run)
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except ValueError as e:
return {"message": str(e)}, 400
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/metrics")
class DatasetEvaluationMetricsApi(Resource):
@console_ns.doc("get_dataset_evaluation_metrics")
@console_ns.response(200, "Available retrieval metrics retrieved")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
"""Get available evaluation metrics for knowledge base retrieval."""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
return {
"metrics": EvaluationService.get_supported_metrics(EvaluationCategory.KNOWLEDGE_BASE)
}
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/files/<uuid:file_id>")
class DatasetEvaluationFileDownloadApi(Resource):
@console_ns.doc("download_dataset_evaluation_file")
@console_ns.response(200, "File download URL generated")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset or file not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, file_id):
"""Download evaluation test file or result file for the knowledge base."""
from core.workflow.file import helpers as file_helpers
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
file_id_str = str(file_id)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(UploadFile).where(
UploadFile.id == file_id_str,
UploadFile.tenant_id == current_tenant_id,
)
upload_file = session.execute(stmt).scalar_one_or_none()
if not upload_file:
raise NotFound("File not found.")
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
"download_url": download_url,
}

View File

@@ -1 +0,0 @@
# Evaluation controller module

View File

@@ -1,642 +0,0 @@
from __future__ import annotations
import logging
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, ParamSpec, TypeVar, Union
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource, fields
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
from graphon.file import helpers as file_helpers
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import App, Dataset
from models.model import UploadFile
from models.snippet import CustomizedSnippet
from services.errors.evaluation import (
EvaluationDatasetInvalidError,
EvaluationFrameworkNotConfiguredError,
EvaluationMaxConcurrentRunsError,
EvaluationNotFoundError,
)
from services.evaluation_service import EvaluationService
if TYPE_CHECKING:
from models.evaluation import EvaluationRun, EvaluationRunItem
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
# Valid evaluation target types
EVALUATE_TARGET_TYPES = {"app", "snippets"}
class VersionQuery(BaseModel):
"""Query parameters for version endpoint."""
version: str
register_schema_models(
console_ns,
VersionQuery,
)
# Response field definitions
file_info_fields = {
"id": fields.String,
"name": fields.String,
}
evaluation_log_fields = {
"created_at": TimestampField,
"created_by": fields.String,
"test_file": fields.Nested(
console_ns.model(
"EvaluationTestFile",
file_info_fields,
)
),
"result_file": fields.Nested(
console_ns.model(
"EvaluationResultFile",
file_info_fields,
),
allow_null=True,
),
"version": fields.String,
}
evaluation_log_list_model = console_ns.model(
"EvaluationLogList",
{
"data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
},
)
customized_matrix_fields = {
"evaluation_workflow_id": fields.String,
"input_fields": fields.Raw,
"output_fields": fields.Raw,
}
condition_fields = {
"name": fields.List(fields.String),
"comparison_operator": fields.String,
"value": fields.String,
}
judgement_conditions_fields = {
"logical_operator": fields.String,
"conditions": fields.List(fields.Nested(console_ns.model("EvaluationCondition", condition_fields))),
}
evaluation_detail_fields = {
"evaluation_model": fields.String,
"evaluation_model_provider": fields.String,
"customized_matrix": fields.Nested(
console_ns.model("EvaluationCustomizedMatrix", customized_matrix_fields),
allow_null=True,
),
"judgement_conditions": fields.Nested(
console_ns.model("EvaluationJudgementConditions", judgement_conditions_fields),
allow_null=True,
),
}
evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)
def get_evaluation_target(view_func: Callable[P, R]):
"""
Decorator to resolve polymorphic evaluation target (app or snippet).
Validates the target_type parameter and fetches the corresponding
model (App or CustomizedSnippet) with tenant isolation.
"""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
target_type = kwargs.get("evaluate_target_type")
target_id = kwargs.get("evaluate_target_id")
if target_type not in EVALUATE_TARGET_TYPES:
raise NotFound(f"Invalid evaluation target type: {target_type}")
_, current_tenant_id = current_account_with_tenant()
target_id = str(target_id)
# Remove path parameters
del kwargs["evaluate_target_type"]
del kwargs["evaluate_target_id"]
target: Union[App, CustomizedSnippet, Dataset] | None = None
if target_type == "app":
target = db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
elif target_type == "snippets":
target = (
db.session.query(CustomizedSnippet)
.where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
.first()
)
elif target_type == "knowledge":
target = (db.session.query(Dataset)
.where(Dataset.id == target_id, Dataset.tenant_id == current_tenant_id)
.first())
if not target:
raise NotFound(f"{str(target_type)} not found")
kwargs["target"] = target
kwargs["target_type"] = target_type
return view_func(*args, **kwargs)
return decorated_view
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
class EvaluationDatasetTemplateDownloadApi(Resource):
@console_ns.doc("download_evaluation_dataset_template")
@console_ns.response(200, "Template file streamed as XLSX attachment")
@console_ns.response(400, "Invalid target type or excluded app mode")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Download evaluation dataset template.
Generates an XLSX template based on the target's input parameters
and streams it directly as a file attachment.
"""
try:
xlsx_content, filename = EvaluationService.generate_dataset_template(
target=target,
target_type=target_type,
)
except ValueError as e:
return {"message": str(e)}, 400
encoded_filename = quote(filename)
response = Response(
xlsx_content,
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Length"] = str(len(xlsx_content))
return response
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
class EvaluationDetailApi(Resource):
@console_ns.doc("get_evaluation_detail")
@console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation configuration for the target.
Returns evaluation configuration including model settings,
metrics config, and judgement conditions.
"""
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.get_evaluation_config(session, current_tenant_id, target_type, str(target.id))
if config is None:
return {
"evaluation_model": None,
"evaluation_model_provider": None,
"metrics_config": None,
"judgement_conditions": None,
}
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"metrics_config": config.metrics_config_dict,
"judgement_conditions": config.judgement_conditions_dict,
}
@console_ns.doc("save_evaluation_detail")
@console_ns.response(200, "Evaluation configuration saved successfully")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def put(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Save evaluation configuration for the target.
"""
current_account, current_tenant_id = current_account_with_tenant()
body = request.get_json(force=True)
try:
config_data = EvaluationConfigData.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
with Session(db.engine, expire_on_commit=False) as session:
config = EvaluationService.save_evaluation_config(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
data=config_data,
)
return {
"evaluation_model": config.evaluation_model,
"evaluation_model_provider": config.evaluation_model_provider,
"metrics_config": config.metrics_config_dict,
"judgement_conditions": config.judgement_conditions_dict,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
class EvaluationLogsApi(Resource):
@console_ns.doc("get_evaluation_logs")
@console_ns.response(200, "Evaluation logs retrieved successfully")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation run history for the target.
Returns a paginated list of evaluation runs.
"""
_, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
with Session(db.engine, expire_on_commit=False) as session:
runs, total = EvaluationService.get_evaluation_runs(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
page=page,
page_size=page_size,
)
return {
"data": [_serialize_evaluation_run(run) for run in runs],
"total": total,
"page": page,
"page_size": page_size,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run")
class EvaluationRunApi(Resource):
@console_ns.doc("start_evaluation_run")
@console_ns.response(200, "Evaluation run started")
@console_ns.response(400, "Invalid request")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str):
"""
Start an evaluation run.
Expects JSON body with:
- file_id: uploaded dataset file ID
- evaluation_model: evaluation model name
- evaluation_model_provider: evaluation model provider
- default_metrics: list of default metric objects
- customized_metrics: customized metrics object (optional)
- judgment_config: judgment conditions config (optional)
"""
current_account, current_tenant_id = current_account_with_tenant()
body = request.get_json(force=True)
if not body:
raise BadRequest("Request body is required.")
# Validate and parse request body
try:
run_request = EvaluationRunRequest.model_validate(body)
except Exception as e:
raise BadRequest(f"Invalid request body: {e}")
# Load dataset file
upload_file = (
db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first()
)
if not upload_file:
raise NotFound("Dataset file not found.")
try:
dataset_content = storage.load_once(upload_file.key)
except Exception:
raise BadRequest("Failed to read dataset file.")
if not dataset_content:
raise BadRequest("Dataset file is empty.")
try:
with Session(db.engine, expire_on_commit=False) as session:
evaluation_run = EvaluationService.start_evaluation_run(
session=session,
tenant_id=current_tenant_id,
target_type=target_type,
target_id=str(target.id),
account_id=str(current_account.id),
dataset_file_content=dataset_content,
run_request=run_request,
)
return _serialize_evaluation_run(evaluation_run), 200
except EvaluationFrameworkNotConfiguredError as e:
return {"message": str(e.description)}, 400
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except EvaluationMaxConcurrentRunsError as e:
return {"message": str(e.description)}, 429
except EvaluationDatasetInvalidError as e:
return {"message": str(e.description)}, 400
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>")
class EvaluationRunDetailApi(Resource):
@console_ns.doc("get_evaluation_run_detail")
@console_ns.response(200, "Evaluation run detail retrieved")
@console_ns.response(404, "Run not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
"""
Get evaluation run detail including items.
"""
_, current_tenant_id = current_account_with_tenant()
run_id = str(run_id)
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 50, type=int)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.get_evaluation_run_detail(
session=session,
tenant_id=current_tenant_id,
run_id=run_id,
)
items, total_items = EvaluationService.get_evaluation_run_items(
session=session,
run_id=run_id,
page=page,
page_size=page_size,
)
return {
"run": _serialize_evaluation_run(run),
"items": {
"data": [_serialize_evaluation_run_item(item) for item in items],
"total": total_items,
"page": page,
"page_size": page_size,
},
}
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>/cancel")
class EvaluationRunCancelApi(Resource):
@console_ns.doc("cancel_evaluation_run")
@console_ns.response(200, "Evaluation run cancelled")
@console_ns.response(404, "Run not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
"""Cancel a running evaluation."""
_, current_tenant_id = current_account_with_tenant()
run_id = str(run_id)
try:
with Session(db.engine, expire_on_commit=False) as session:
run = EvaluationService.cancel_evaluation_run(
session=session,
tenant_id=current_tenant_id,
run_id=run_id,
)
return _serialize_evaluation_run(run)
except EvaluationNotFoundError as e:
return {"message": str(e.description)}, 404
except ValueError as e:
return {"message": str(e)}, 400
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/metrics")
class EvaluationMetricsApi(Resource):
@console_ns.doc("get_evaluation_metrics")
@console_ns.response(200, "Available metrics retrieved")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get available evaluation metrics for the current framework.
"""
result = {}
for category in EvaluationCategory:
result[category.value] = EvaluationService.get_supported_metrics(category)
return {"metrics": result}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/node-info")
class EvaluationNodeInfoApi(Resource):
@console_ns.doc("get_evaluation_node_info")
@console_ns.response(200, "Node info grouped by metric")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""Return workflow/snippet node info grouped by requested metrics.
Request body (JSON):
- metrics (list[str] | None): metric names to query; omit or pass
an empty list to get all nodes under key ``"all"``.
Response:
``{metric_or_all: [{"node_id": ..., "type": ..., "title": ...}, ...]}``
"""
body = request.get_json(silent=True) or {}
metrics: list[str] | None = body.get("metrics") or None
result = EvaluationService.get_nodes_for_metrics(
target=target,
target_type=target_type,
metrics=metrics,
)
return result
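# Purely illustrative sketch (not from the repository): a request body and the
# grouped response shape described in the docstring above. Metric names, node
# ids and titles are placeholders.
node_info_request_body = {"metrics": ["<metric-a>", "<metric-b>"]}
node_info_response_shape = {
    "<metric-a>": [{"node_id": "<node id>", "type": "<node type>", "title": "<node title>"}],
    "<metric-b>": [{"node_id": "<node id>", "type": "<node type>", "title": "<node title>"}],
}
# Omitting "metrics" (or sending an empty list) returns all nodes under the key "all".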
@console_ns.route("/evaluation/available-metrics")
class EvaluationAvailableMetricsApi(Resource):
@console_ns.doc("get_available_evaluation_metrics")
@console_ns.response(200, "Available metrics list")
@setup_required
@login_required
@account_initialization_required
def get(self):
"""Return the centrally-defined list of evaluation metrics."""
return {"metrics": EvaluationService.get_available_metrics()}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
class EvaluationFileDownloadApi(Resource):
@console_ns.doc("download_evaluation_file")
@console_ns.response(200, "File download URL generated successfully")
@console_ns.response(404, "Target or file not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
"""
Download evaluation test file or result file.
Looks up the specified file, verifies it belongs to the same tenant,
and returns file info and download URL.
"""
file_id = str(file_id)
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(UploadFile).where(
UploadFile.id == file_id,
UploadFile.tenant_id == current_tenant_id,
)
upload_file = session.execute(stmt).scalar_one_or_none()
if not upload_file:
raise NotFound("File not found")
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
"download_url": download_url,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
class EvaluationVersionApi(Resource):
@console_ns.doc("get_evaluation_version_detail")
@console_ns.expect(console_ns.models.get(VersionQuery.__name__))
@console_ns.response(200, "Version details retrieved successfully")
@console_ns.response(404, "Target or version not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation target version details.
Returns the workflow graph for the specified version.
"""
version = request.args.get("version")
if not version:
return {"message": "version parameter is required"}, 400
graph = {}
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
graph = target.graph_dict
return {
"graph": graph,
}
# ---- Serialization Helpers ----
def _serialize_evaluation_run(run: EvaluationRun) -> dict[str, object]:
return {
"id": run.id,
"tenant_id": run.tenant_id,
"target_type": run.target_type,
"target_id": run.target_id,
"evaluation_config_id": run.evaluation_config_id,
"status": run.status,
"dataset_file_id": run.dataset_file_id,
"result_file_id": run.result_file_id,
"total_items": run.total_items,
"completed_items": run.completed_items,
"failed_items": run.failed_items,
"progress": run.progress,
"metrics_summary": run.metrics_summary_dict,
"error": run.error,
"created_by": run.created_by,
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
}
def _serialize_evaluation_run_item(item: EvaluationRunItem) -> dict[str, object]:
return {
"id": item.id,
"item_index": item.item_index,
"inputs": item.inputs_dict,
"expected_output": item.expected_output,
"actual_output": item.actual_output,
"metrics": item.metrics_list,
"judgment": item.judgment_dict,
"metadata": item.metadata_dict,
"error": item.error,
"overall_score": item.overall_score,
}

View File

@@ -15,6 +15,7 @@ from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
@@ -166,6 +167,7 @@ class ConsoleWorkflowEventsApi(Resource):
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app.mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app.mode == AppMode.WORKFLOW:
@@ -202,7 +204,7 @@ class ConsoleWorkflowEventsApi(Resource):
)
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun) -> App:
query = select(App).where(
App.id == workflow_run.app_id,
App.tenant_id == workflow_run.tenant_id,

View File

@@ -1,133 +0,0 @@
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
class SnippetListQuery(BaseModel):
"""Query parameters for listing snippets."""
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=20, ge=1, le=100)
keyword: str | None = None
is_published: bool | None = Field(default=None, description="Filter by published status")
creators: list[str] | None = Field(default=None, description="Filter by creator account IDs")
@field_validator("creators", mode="before")
@classmethod
def parse_creators(cls, value: object) -> list[str] | None:
"""Normalize creators filter from query string or list input."""
if value is None:
return None
if isinstance(value, str):
return [creator.strip() for creator in value.split(",") if creator.strip()] or None
if isinstance(value, list):
return [str(creator).strip() for creator in value if str(creator).strip()] or None
return None
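# Illustrative check (assumption: this payloads module is importable) of the
# creators normalization above: a comma-separated string and a list collapse to
# the same cleaned list, and empty input becomes None.
if __name__ == "__main__":
    assert SnippetListQuery.model_validate({"creators": "a1, b2,"}).creators == ["a1", "b2"]
    assert SnippetListQuery.model_validate({"creators": ["a1", " b2 "]}).creators == ["a1", "b2"]
    assert SnippetListQuery.model_validate({"creators": ""}).creators is None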
class IconInfo(BaseModel):
"""Icon information model."""
icon: str | None = None
icon_type: Literal["emoji", "image"] | None = None
icon_background: str | None = None
icon_url: str | None = None
class InputFieldDefinition(BaseModel):
"""Input field definition for snippet parameters."""
default: str | None = None
hint: bool | None = None
label: str | None = None
max_length: int | None = None
options: list[str] | None = None
placeholder: str | None = None
required: bool | None = None
type: str | None = None # e.g., "text-input"
class CreateSnippetPayload(BaseModel):
"""Payload for creating a new snippet."""
name: str = Field(..., min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
type: Literal["node", "group"] = "node"
icon_info: IconInfo | None = None
graph: dict[str, Any] | None = None
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
class UpdateSnippetPayload(BaseModel):
"""Payload for updating a snippet."""
name: str | None = Field(default=None, min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
icon_info: IconInfo | None = None
class SnippetDraftSyncPayload(BaseModel):
"""Payload for syncing snippet draft workflow."""
graph: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] | None = None
conversation_variables: list[dict[str, Any]] | None = None
input_variables: list[dict[str, Any]] | None = None
class WorkflowRunQuery(BaseModel):
"""Query parameters for workflow runs."""
last_id: str | None = None
limit: int = Field(default=20, ge=1, le=100)
class SnippetDraftRunPayload(BaseModel):
"""Payload for running snippet draft workflow."""
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
class SnippetDraftNodeRunPayload(BaseModel):
"""Payload for running a single node in snippet draft workflow."""
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
class SnippetIterationNodeRunPayload(BaseModel):
"""Payload for running an iteration node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class SnippetLoopNodeRunPayload(BaseModel):
"""Payload for running a loop node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class PublishWorkflowPayload(BaseModel):
"""Payload for publishing snippet workflow."""
knowledge_base_setting: dict[str, Any] | None = None
class SnippetImportPayload(BaseModel):
"""Payload for importing snippet from DSL."""
mode: str = Field(..., description="Import mode: yaml-content or yaml-url")
yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)")
yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)")
name: str | None = Field(default=None, description="Override snippet name")
description: str | None = Field(default=None, description="Override snippet description")
snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)")
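# Illustrative payloads (placeholders only, not repository data) for the two
# modes accepted by SnippetImportPayload: inline YAML content, or a URL to
# fetch the DSL from.
example_import_from_content = {"mode": "yaml-content", "yaml_content": "<snippet DSL YAML>"}
example_import_from_url = {"mode": "yaml-url", "yaml_url": "https://example.com/snippet.yaml"}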
class IncludeSecretQuery(BaseModel):
"""Query parameter for including secret variables in export."""
include_secret: str = Field(default="false", description="Whether to include secret variables")

View File

@@ -1,541 +0,0 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource, marshal_with
from sqlalchemy.orm import Session
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow import workflow_model
from controllers.console.app.workflow_run import (
workflow_run_detail_model,
workflow_run_node_execution_list_model,
workflow_run_node_execution_model,
workflow_run_pagination_model,
)
from controllers.console.snippets.payloads import (
PublishWorkflowPayload,
SnippetDraftNodeRunPayload,
SnippetDraftRunPayload,
SnippetDraftSyncPayload,
SnippetIterationNodeRunPayload,
SnippetLoopNodeRunPayload,
WorkflowRunQuery,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from graphon.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import variable_factory
from libs import helper
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.snippet import CustomizedSnippet
from services.errors.app import WorkflowHashNotEqualError
from services.snippet_generate_service import SnippetGenerateService
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetDraftSyncPayload,
SnippetDraftNodeRunPayload,
SnippetDraftRunPayload,
SnippetIterationNodeRunPayload,
SnippetLoopNodeRunPayload,
WorkflowRunQuery,
PublishWorkflowPayload,
)
class SnippetNotFoundError(Exception):
"""Snippet not found error."""
pass
def get_snippet(view_func: Callable[P, R]):
"""Decorator to fetch and validate snippet access."""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("snippet_id"):
raise ValueError("missing snippet_id in path parameters")
_, current_tenant_id = current_account_with_tenant()
snippet_id = str(kwargs.get("snippet_id"))
del kwargs["snippet_id"]
snippet = SnippetService.get_snippet_by_id(
snippet_id=snippet_id,
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
kwargs["snippet"] = snippet
return view_func(*args, **kwargs)
return decorated_view
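# Toy illustration (assumption, not repository code) of the kwarg rewrite that
# get_snippet performs: the snippet_id path parameter is removed and replaced
# with the loaded snippet object before the view runs. `wraps` is already
# imported at the top of this module.
def _inject_snippet_demo(view_func):
    @wraps(view_func)
    def wrapper(*args, **kwargs):
        snippet_id = kwargs.pop("snippet_id")  # path parameter goes in ...
        kwargs["snippet"] = {"id": snippet_id}  # ... loaded object comes out
        return view_func(*args, **kwargs)
    return wrapper
# With the real decorator, a view receives `snippet=<CustomizedSnippet>` instead of
# `snippet_id=<uuid>`, and a missing snippet raises NotFound.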
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
class SnippetDraftWorkflowApi(Resource):
@console_ns.doc("get_snippet_draft_workflow")
@console_ns.response(200, "Draft workflow retrieved successfully", workflow_model)
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@marshal_with(workflow_model)
def get(self, snippet: CustomizedSnippet):
"""Get draft workflow for snippet."""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise DraftWorkflowNotExist()
return workflow
@console_ns.doc("sync_snippet_draft_workflow")
@console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
@console_ns.response(200, "Draft workflow synced successfully")
@console_ns.response(400, "Hash mismatch")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Sync draft workflow for snippet."""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
try:
environment_variables_list = payload.environment_variables or []
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = payload.conversation_variables or []
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
snippet_service = SnippetService()
workflow = snippet_service.sync_draft_workflow(
snippet=snippet,
graph=payload.graph,
unique_hash=payload.hash,
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
input_variables=payload.input_variables,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
class SnippetDraftConfigApi(Resource):
@console_ns.doc("get_snippet_draft_config")
@console_ns.response(200, "Draft config retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get snippet draft workflow configuration limits."""
return {
"parallel_depth_limit": 3,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
class SnippetPublishedWorkflowApi(Resource):
@console_ns.doc("get_snippet_published_workflow")
@console_ns.response(200, "Published workflow retrieved successfully", workflow_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@marshal_with(workflow_model)
def get(self, snippet: CustomizedSnippet):
"""Get published workflow for snippet."""
if not snippet.is_published:
return None
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet=snippet)
return workflow
@console_ns.doc("publish_snippet_workflow")
@console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
@console_ns.response(200, "Workflow published successfully")
@console_ns.response(400, "No draft workflow found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Publish snippet workflow."""
current_user, _ = current_account_with_tenant()
snippet_service = SnippetService()
with Session(db.engine) as session:
snippet = session.merge(snippet)
try:
workflow = snippet_service.publish_workflow(
session=session,
snippet=snippet,
account=current_user,
)
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return {
"result": "success",
"created_at": workflow_created_at,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
class SnippetDefaultBlockConfigsApi(Resource):
@console_ns.doc("get_snippet_default_block_configs")
@console_ns.response(200, "Default block configs retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get default block configurations for snippet workflow."""
snippet_service = SnippetService()
return snippet_service.get_default_block_configs()
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
class SnippetWorkflowRunsApi(Resource):
@console_ns.doc("list_snippet_workflow_runs")
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_pagination_model)
def get(self, snippet: CustomizedSnippet):
"""List workflow runs for snippet."""
query = WorkflowRunQuery.model_validate(
{
"last_id": request.args.get("last_id"),
"limit": request.args.get("limit", type=int, default=20),
}
)
args = {
"last_id": query.last_id,
"limit": query.limit,
}
snippet_service = SnippetService()
result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
return result
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
class SnippetWorkflowRunDetailApi(Resource):
@console_ns.doc("get_snippet_workflow_run_detail")
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_detail_model)
def get(self, snippet: CustomizedSnippet, run_id):
"""Get workflow run detail for snippet."""
run_id = str(run_id)
snippet_service = SnippetService()
workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
if not workflow_run:
raise NotFound("Workflow run not found")
return workflow_run
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
class SnippetWorkflowRunNodeExecutionsApi(Resource):
@console_ns.doc("list_snippet_workflow_run_node_executions")
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_list_model)
def get(self, snippet: CustomizedSnippet, run_id):
"""List node executions for a workflow run."""
run_id = str(run_id)
snippet_service = SnippetService()
node_executions = snippet_service.get_snippet_workflow_run_node_executions(
snippet=snippet,
run_id=run_id,
)
return {"data": node_executions}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
class SnippetDraftNodeRunApi(Resource):
@console_ns.doc("run_snippet_draft_node")
@console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
@console_ns.response(200, "Node run completed successfully", workflow_run_node_execution_model)
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_model)
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a single node in snippet draft workflow.
Executes a specific node with provided inputs for single-step debugging.
Returns the node execution result including status, outputs, and timing.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
user_inputs = payload.inputs
# Get draft workflow for file parsing
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise NotFound("Draft workflow not found")
files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
workflow_node_execution = SnippetGenerateService.run_draft_node(
snippet=snippet,
node_id=node_id,
user_inputs=user_inputs,
account=current_user,
query=payload.query,
files=files,
)
return workflow_node_execution
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
class SnippetDraftNodeLastRunApi(Resource):
@console_ns.doc("get_snippet_draft_node_last_run")
@console_ns.doc(description="Get last run result for a node in snippet draft workflow")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
@console_ns.response(404, "Snippet, draft workflow, or node last run not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_model)
def get(self, snippet: CustomizedSnippet, node_id: str):
"""
Get the last run result for a specific node in snippet draft workflow.
Returns the most recent execution record for the given node,
including status, inputs, outputs, and timing information.
"""
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise NotFound("Draft workflow not found")
node_exec = snippet_service.get_snippet_node_last_run(
snippet=snippet,
workflow=draft_workflow,
node_id=node_id,
)
if node_exec is None:
raise NotFound("Node last run not found")
return node_exec
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class SnippetDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_snippet_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node for snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
@console_ns.response(200, "Iteration node run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow iteration node for snippet.
Iteration nodes execute their internal sub-graph multiple times over an input list.
Returns an SSE event stream with iteration progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = SnippetGenerateService.generate_single_iteration(
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class SnippetDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_snippet_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node for snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
@console_ns.response(200, "Loop node run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow loop node for snippet.
Loop nodes execute their internal sub-graph repeatedly until a condition is met.
Returns an SSE event stream with loop progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
response = SnippetGenerateService.generate_single_loop(
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
class SnippetDraftWorkflowRunApi(Resource):
@console_ns.doc("run_snippet_draft_workflow")
@console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
@console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""
Run draft workflow for snippet.
Executes the snippet's draft workflow with the provided inputs
and returns an SSE event stream with execution progress and results.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try:
response = SnippetGenerateService.generate(
snippet=snippet,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
class SnippetWorkflowTaskStopApi(Resource):
@console_ns.doc("stop_snippet_workflow_task")
@console_ns.response(200, "Task stopped successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, task_id: str):
"""
Stop a running snippet workflow task.
Uses both the legacy stop flag mechanism and the graph engine
command channel for backward compatibility.
"""
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@@ -200,7 +200,7 @@ class PluginDebuggingKeyApi(Resource):
"port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
}
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
@console_ns.route("/workspaces/current/plugin/list")
@@ -215,7 +215,7 @@ class PluginListApi(Resource):
try:
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
@@ -232,7 +232,7 @@ class PluginListLatestVersionsApi(Resource):
try:
versions = PluginService.list_latest_versions(args.plugin_ids)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
return jsonable_encoder({"versions": versions})
@@ -251,7 +251,7 @@ class PluginListInstallationsFromIdsApi(Resource):
try:
plugins = PluginService.list_installations_from_ids(tenant_id, args.plugin_ids)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
return jsonable_encoder({"plugins": plugins})
@@ -266,7 +266,7 @@ class PluginIconApi(Resource):
try:
icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@@ -286,7 +286,7 @@ class PluginAssetApi(Resource):
binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name)
return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
@console_ns.route("/workspaces/current/plugin/upload/pkg")
@@ -303,7 +303,7 @@ class PluginUploadFromPkgApi(Resource):
try:
response = PluginService.upload_pkg(tenant_id, content)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
return jsonable_encoder(response)
@@ -323,7 +323,7 @@ class PluginUploadFromGithubApi(Resource):
try:
response = PluginService.upload_pkg_from_github(tenant_id, args.repo, args.version, args.package)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
return jsonable_encoder(response)
@@ -361,7 +361,7 @@ class PluginInstallFromPkgApi(Resource):
try:
response = PluginService.install_from_local_pkg(tenant_id, args.plugin_unique_identifiers)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
return jsonable_encoder(response)
@@ -387,7 +387,7 @@ class PluginInstallFromGithubApi(Resource):
args.package,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
return jsonable_encoder(response)
@@ -407,7 +407,7 @@ class PluginInstallFromMarketplaceApi(Resource):
try:
response = PluginService.install_from_marketplace_pkg(tenant_id, args.plugin_unique_identifiers)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
return jsonable_encoder(response)
@@ -433,7 +433,7 @@ class PluginFetchMarketplacePkgApi(Resource):
}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
@@ -453,7 +453,7 @@ class PluginFetchManifestApi(Resource):
{"manifest": PluginService.fetch_plugin_manifest(tenant_id, args.plugin_unique_identifier).model_dump()}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
@console_ns.route("/workspaces/current/plugin/tasks")
@@ -471,7 +471,7 @@ class PluginFetchInstallTasksApi(Resource):
try:
return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)})
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>")
@@ -486,7 +486,7 @@ class PluginFetchInstallTaskApi(Resource):
try:
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete")
@@ -501,7 +501,7 @@ class PluginDeleteInstallTaskApi(Resource):
try:
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
@console_ns.route("/workspaces/current/plugin/tasks/delete_all")
@@ -516,7 +516,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
@@ -531,7 +531,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
try:
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
@@ -553,7 +553,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
@console_ns.route("/workspaces/current/plugin/upgrade/github")
@@ -580,7 +580,7 @@ class PluginUpgradeFromGithubApi(Resource):
)
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
@console_ns.route("/workspaces/current/plugin/uninstall")
@@ -598,7 +598,7 @@ class PluginUninstallApi(Resource):
try:
return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)}
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
@console_ns.route("/workspaces/current/plugin/permission/change")
@@ -674,7 +674,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
provider_type=args.provider_type,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
return jsonable_encoder({"options": options})
@@ -705,7 +705,7 @@ class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
credentials=args.credentials,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return {"code": "plugin_error", "message": e.description}, 400
return jsonable_encoder({"options": options})

View File

@@ -1,380 +0,0 @@
import logging
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource, marshal
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.snippets.payloads import (
CreateSnippetPayload,
IncludeSecretQuery,
SnippetImportPayload,
SnippetListQuery,
UpdateSnippetPayload,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
from libs.login import current_account_with_tenant, login_required
from models.snippet import SnippetType
from services.app_dsl_service import ImportStatus
from services.snippet_dsl_service import SnippetDslService
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetListQuery,
CreateSnippetPayload,
UpdateSnippetPayload,
SnippetImportPayload,
IncludeSecretQuery,
)
# Create namespace models for marshaling
snippet_model = console_ns.model("Snippet", snippet_fields)
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
@console_ns.route("/workspaces/current/customized-snippets")
class CustomizedSnippetsApi(Resource):
@console_ns.doc("list_customized_snippets")
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List customized snippets with pagination and search."""
_, current_tenant_id = current_account_with_tenant()
query_params = request.args.to_dict()
query = SnippetListQuery.model_validate(query_params)
snippets, total, has_more = SnippetService.get_snippets(
tenant_id=current_tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
is_published=query.is_published,
creators=query.creators,
)
return {
"data": marshal(snippets, snippet_list_fields),
"page": query.page,
"limit": query.limit,
"total": total,
"has_more": has_more,
}, 200
@console_ns.doc("create_customized_snippet")
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
@console_ns.response(201, "Snippet created successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
"""Create a new customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
try:
snippet_type = SnippetType(payload.type)
except ValueError:
snippet_type = SnippetType.NODE
try:
snippet = SnippetService.create_snippet(
tenant_id=current_tenant_id,
name=payload.name,
description=payload.description,
snippet_type=snippet_type,
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
account=current_user,
)
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 201
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
class CustomizedSnippetDetailApi(Resource):
@console_ns.doc("get_customized_snippet")
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
def get(self, snippet_id: str):
"""Get customized snippet details."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
return marshal(snippet, snippet_fields), 200
@console_ns.doc("update_customized_snippet")
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
@console_ns.response(200, "Snippet updated successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, snippet_id: str):
"""Update customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
update_data = payload.model_dump(exclude_unset=True)
if "icon_info" in update_data and update_data["icon_info"] is not None:
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
if not update_data:
return {"message": "No valid fields to update"}, 400
try:
with Session(db.engine, expire_on_commit=False) as session:
snippet = session.merge(snippet)
snippet = SnippetService.update_snippet(
session=session,
snippet=snippet,
account_id=current_user.id,
data=update_data,
)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 200
@console_ns.doc("delete_customized_snippet")
@console_ns.response(204, "Snippet deleted successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, snippet_id: str):
"""Delete customized snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
snippet = session.merge(snippet)
SnippetService.delete_snippet(
session=session,
snippet=snippet,
)
session.commit()
return "", 204
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/export")
class CustomizedSnippetExportApi(Resource):
@console_ns.doc("export_customized_snippet")
@console_ns.doc(description="Export snippet configuration as DSL")
@console_ns.doc(params={"snippet_id": "Snippet ID to export"})
@console_ns.response(200, "Snippet exported successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, snippet_id: str):
"""Export snippet as DSL."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
# Get include_secret parameter
query = IncludeSecretQuery.model_validate(request.args.to_dict())
with Session(db.engine) as session:
export_service = SnippetDslService(session)
result = export_service.export_snippet_dsl(snippet=snippet, include_secret=query.include_secret == "true")
# Set filename with .snippet extension
filename = f"{snippet.name}.snippet"
encoded_filename = quote(filename)
response = Response(
result,
mimetype="application/x-yaml",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/x-yaml"
return response
@console_ns.route("/workspaces/current/customized-snippets/imports")
class CustomizedSnippetImportApi(Resource):
@console_ns.doc("import_customized_snippet")
@console_ns.doc(description="Import snippet from DSL")
@console_ns.expect(console_ns.models.get(SnippetImportPayload.__name__))
@console_ns.response(200, "Snippet imported successfully")
@console_ns.response(202, "Import pending confirmation")
@console_ns.response(400, "Import failed")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
"""Import snippet from DSL."""
current_user, _ = current_account_with_tenant()
payload = SnippetImportPayload.model_validate(console_ns.payload or {})
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.import_snippet(
account=current_user,
import_mode=payload.mode,
yaml_content=payload.yaml_content,
yaml_url=payload.yaml_url,
snippet_id=payload.snippet_id,
name=payload.name,
description=payload.description,
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/imports/<string:import_id>/confirm")
class CustomizedSnippetImportConfirmApi(Resource):
@console_ns.doc("confirm_snippet_import")
@console_ns.doc(description="Confirm a pending snippet import")
@console_ns.doc(params={"import_id": "Import ID to confirm"})
@console_ns.response(200, "Import confirmed successfully")
@console_ns.response(400, "Import failed")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, import_id: str):
"""Confirm a pending snippet import."""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.confirm_import(import_id=import_id, account=current_user)
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/check-dependencies")
class CustomizedSnippetCheckDependenciesApi(Resource):
@console_ns.doc("check_snippet_dependencies")
@console_ns.doc(description="Check dependencies for a snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID"})
@console_ns.response(200, "Dependencies checked successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, snippet_id: str):
"""Check dependencies for a snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.check_dependencies(snippet=snippet)
return result.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>/use-count/increment")
class CustomizedSnippetUseCountIncrementApi(Resource):
@console_ns.doc("increment_snippet_use_count")
@console_ns.doc(description="Increment snippet use count by 1")
@console_ns.doc(params={"snippet_id": "Snippet ID"})
@console_ns.response(200, "Use count incremented successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, snippet_id: str):
"""Increment snippet use count when it is inserted into a workflow."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
snippet = session.merge(snippet)
SnippetService.increment_use_count(session=session, snippet=snippet)
session.commit()
session.refresh(snippet)
return {"result": "success", "use_count": snippet.use_count}, 200

View File

@@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -22,7 +22,12 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.advanced_chat.generate_task_pipeline import (
AdvancedChatAppGenerateTaskPipeline,
ConversationSnapshot,
MessageSnapshot,
WorkflowSnapshot,
)
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.draft_variable_saver import DraftVariableSaverFactory
from core.app.apps.exc import GenerateTaskStoppedError
@@ -44,7 +49,6 @@ from graphon.runtime import GraphRuntimeState
from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.base import Base
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
@@ -524,19 +528,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
# Capture the scalar fields needed by the response pipeline before
# releasing the request-scoped SQLAlchemy session.
workflow_snapshot = WorkflowSnapshot.from_workflow(workflow)
conversation_snapshot = ConversationSnapshot.from_conversation(conversation)
message_snapshot = MessageSnapshot.from_message(message)
db.session.close()
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
workflow=workflow_snapshot,
queue_manager=queue_manager,
conversation=conversation,
message=message,
conversation=conversation_snapshot,
message=message_snapshot,
user=user,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
@@ -643,10 +648,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
workflow: WorkflowSnapshot,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
conversation: ConversationSnapshot,
message: MessageSnapshot,
user: Union[Account, EndUser],
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
@@ -683,13 +688,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
else:
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
raise e
_T = TypeVar("_T", bound=Base)
def _refresh_model(session, model: _T) -> _T:
with Session(bind=db.engine, expire_on_commit=False) as session:
detach_model = session.get(type(model), model.id)
assert detach_model is not None
return detach_model

View File

@@ -4,6 +4,8 @@ import re
import time
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from threading import Thread
from typing import Any, Union
@@ -79,11 +81,59 @@ from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
from models.execution_extra_content import HumanInputContent
from models.model import AppMode
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True)
class WorkflowSnapshot:
id: str
tenant_id: str
features_dict: Mapping[str, Any]
@classmethod
def from_workflow(cls, workflow: Workflow) -> "WorkflowSnapshot":
return cls(
id=workflow.id,
tenant_id=workflow.tenant_id,
features_dict=dict(workflow.features_dict),
)
@dataclass(frozen=True, slots=True)
class ConversationSnapshot:
id: str
mode: AppMode
@classmethod
def from_conversation(cls, conversation: Conversation) -> "ConversationSnapshot":
return cls(
id=conversation.id,
mode=conversation.mode,
)
@dataclass(frozen=True, slots=True)
class MessageSnapshot:
id: str
query: str
created_at: datetime
status: MessageStatus
answer: str
@classmethod
def from_message(cls, message: Message) -> "MessageSnapshot":
return cls(
id=message.id,
query=message.query,
created_at=message.created_at,
status=message.status,
answer=message.answer,
)
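# Minimal sketch (assumption, placeholder values): snapshots are frozen
# dataclasses, so they can be built while the SQLAlchemy session is still open
# and read freely after it closes, with no lazy loading and no detached-instance
# errors. In the generator they are built via from_workflow / from_conversation /
# from_message.
_example_snapshot = WorkflowSnapshot(
    id="<workflow id>",
    tenant_id="<tenant id>",
    features_dict={"file_upload": {"enabled": False}},
)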
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
AdvancedChatAppGenerateTaskPipeline generates streaming output and manages state for the application.
@@ -92,10 +142,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
workflow: WorkflowSnapshot,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
conversation: ConversationSnapshot,
message: MessageSnapshot,
user: Union[Account, EndUser],
stream: bool,
dialogue_count: int,
@@ -156,7 +206,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._message_saved_on_pause = False
self._seed_graph_runtime_state_from_queue_manager()
def _seed_task_state_from_message(self, message: Message) -> None:
def _seed_task_state_from_message(self, message: MessageSnapshot) -> None:
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer

View File

@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
@@ -54,25 +54,6 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@staticmethod
def _ensure_snippet_start_node_in_worker(*, session: Session, workflow: Workflow) -> Workflow:
"""Re-apply snippet virtual Start injection after worker reloads workflow from DB."""
if workflow.type != "snippet":
return workflow
from models.snippet import CustomizedSnippet
from services.snippet_generate_service import SnippetGenerateService
snippet = session.scalar(
select(CustomizedSnippet).where(
CustomizedSnippet.id == workflow.app_id,
CustomizedSnippet.tenant_id == workflow.tenant_id,
)
)
if snippet is None:
return workflow
return SnippetGenerateService.ensure_start_node_for_worker(workflow, snippet)
@staticmethod
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
@@ -576,8 +557,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
if workflow is None:
raise ValueError("Workflow not found")
workflow = self._ensure_snippet_start_node_in_worker(session=session, workflow=workflow)
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,

View File

@@ -1,271 +0,0 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any
from core.evaluation.entities.evaluation_entity import (
CustomizedMetrics,
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
)
from graphon.node_events.base import NodeRunResult
logger = logging.getLogger(__name__)
class BaseEvaluationInstance(ABC):
"""Abstract base class for evaluation framework adapters."""
@abstractmethod
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate LLM outputs using the configured framework."""
...
@abstractmethod
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate retrieval quality using the configured framework."""
...
@abstractmethod
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate agent outputs using the configured framework."""
...
@abstractmethod
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
"""Return the list of supported metric names for a given evaluation category."""
...
def evaluate_with_customized_workflow(
self,
node_run_result_mapping_list: list[dict[str, NodeRunResult]],
customized_metrics: CustomizedMetrics,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Evaluate using a published workflow as the evaluator.
The evaluator workflow's output variables are treated as metrics:
each output variable name becomes a metric name, and its value
becomes the score.
Args:
node_run_result_mapping_list: One mapping per test-data item,
where each mapping is ``{node_id: NodeRunResult}`` from the
target execution.
customized_metrics: Contains ``evaluation_workflow_id`` (the
published evaluator workflow) and ``input_fields`` (value
sources for the evaluator's input variables).
tenant_id: Tenant scope.
Returns:
A list of ``EvaluationItemResult`` with metrics extracted from
the evaluator workflow's output variables.
"""
from sqlalchemy.orm import Session
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.evaluation.runners import get_service_account_for_app
from models.engine import db
from models.model import App
from services.workflow_service import WorkflowService
workflow_id = customized_metrics.evaluation_workflow_id
if not workflow_id:
raise ValueError("customized_metrics must contain 'evaluation_workflow_id' for customized evaluator")
# Load the evaluator workflow resources using a dedicated session
with Session(db.engine, expire_on_commit=False) as session, session.begin():
app = session.query(App).filter_by(id=workflow_id, tenant_id=tenant_id).first()
if not app:
raise ValueError(f"Evaluation workflow app {workflow_id} not found in tenant {tenant_id}")
service_account = get_service_account_for_app(session, workflow_id)
workflow_service = WorkflowService()
published_workflow = workflow_service.get_published_workflow(app_model=app)
if not published_workflow:
raise ValueError(f"No published workflow found for evaluation app {workflow_id}")
eval_results: list[EvaluationItemResult] = []
for idx, node_run_result_mapping in enumerate(node_run_result_mapping_list):
try:
workflow_inputs = self._build_workflow_inputs(
customized_metrics.input_fields,
node_run_result_mapping,
)
generator = WorkflowAppGenerator()
response: Mapping[str, Any] = generator.generate(
app_model=app,
workflow=published_workflow,
user=service_account,
args={"inputs": workflow_inputs},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
call_depth=0,
)
metrics = self._extract_workflow_metrics(response)
eval_results.append(
EvaluationItemResult(
index=idx,
metrics=metrics,
)
)
except Exception:
logger.exception(
"Customized evaluator failed for item %d with workflow %s",
idx,
workflow_id,
)
eval_results.append(EvaluationItemResult(index=idx))
return eval_results
@staticmethod
def _build_workflow_inputs(
input_fields: dict[str, Any],
node_run_result_mapping: dict[str, NodeRunResult],
) -> dict[str, Any]:
"""Build customized workflow inputs by resolving value sources.
Each entry in ``input_fields`` maps a workflow input variable name
to its value source, which can be:
- **Constant**: a plain string without ``{{#…#}}`` used as-is.
- **Expression**: a string containing one or more
``{{#node_id.output_key#}}`` selectors (same format as
``VariableTemplateParser``) resolved from
``node_run_result_mapping``.
"""
from graphon.nodes.base.variable_template_parser import REGEX as VARIABLE_REGEX
workflow_inputs: dict[str, Any] = {}
for field_name, value_source in input_fields.items():
if not isinstance(value_source, str):
# Non-string values (numbers, bools, dicts) are used directly.
workflow_inputs[field_name] = value_source
continue
# Check if the entire value is a single expression.
full_match = VARIABLE_REGEX.fullmatch(value_source)
if full_match:
workflow_inputs[field_name] = resolve_variable_selector(
full_match.group(1),
node_run_result_mapping,
)
elif VARIABLE_REGEX.search(value_source):
# Mixed template: interpolate all expressions as strings.
workflow_inputs[field_name] = VARIABLE_REGEX.sub(
lambda m: str(resolve_variable_selector(m.group(1), node_run_result_mapping)),
value_source,
)
else:
# Plain constant — no expression markers.
workflow_inputs[field_name] = value_source
return workflow_inputs
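To make the three value-source forms in the docstring concrete, here is a hypothetical input_fields mapping and the result _build_workflow_inputs would produce for it; the node id, output keys, and values are invented for illustration:

input_fields = {
    "model_answer": "{{#llm_node.text#}}",                                   # full expression
    "summary": "Answer: {{#llm_node.text#}} ({{#llm_node.usage#}} tokens)",  # mixed template
    "threshold": 0.8,                                                        # non-string constant, used as-is
}
# Assuming node_run_result_mapping has an entry "llm_node" whose outputs are
# {"text": "Paris", "usage": 42}, the resolved inputs would be:
# {"model_answer": "Paris", "summary": "Answer: Paris (42 tokens)", "threshold": 0.8}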
@staticmethod
def _extract_workflow_metrics(
response: Mapping[str, object],
) -> list[EvaluationMetric]:
"""Extract evaluation metrics from workflow output variables."""
metrics: list[EvaluationMetric] = []
data = response.get("data")
if not isinstance(data, Mapping):
logger.warning("Unexpected workflow response format: missing 'data' dict")
return metrics
outputs = data.get("outputs")
if not isinstance(outputs, dict):
logger.warning("Unexpected workflow response format: 'outputs' is not a dict")
return metrics
for key, raw_value in outputs.items():
if not isinstance(key, str):
continue
metrics.append(EvaluationMetric(name=key, value=raw_value))
return metrics
def resolve_variable_selector(
selector_raw: str,
node_run_result_mapping: dict[str, NodeRunResult],
) -> object:
"""
Resolve a ``#node_id.output_key#`` selector against node run results.
"""
# Strip the surrounding '#' markers before splitting into node_id and output path.
cleaned = selector_raw.strip("#")
parts = cleaned.split(".")
if len(parts) < 2:
logger.warning(
"Selector '%s' must have at least node_id.output_key",
selector_raw,
)
return ""
node_id = parts[0]
output_path = parts[1:]
node_result = node_run_result_mapping.get(node_id)
if not node_result or not node_result.outputs:
logger.warning(
"Selector '%s': node '%s' not found or has no outputs",
selector_raw,
node_id,
)
return ""
# Traverse the output path to support nested keys.
current: object = node_result.outputs
for key in output_path:
if isinstance(current, Mapping):
next_val = current.get(key)
if next_val is None:
logger.warning(
"Selector '%s': key '%s' not found in node '%s' outputs",
selector_raw,
key,
node_id,
)
return ""
current = next_val
else:
logger.warning(
"Selector '%s': cannot traverse into non-dict value at key '%s'",
selector_raw,
key,
)
return ""
return current if current is not None else ""

View File

@@ -1,27 +0,0 @@
from enum import StrEnum
from pydantic import BaseModel
class EvaluationFrameworkEnum(StrEnum):
RAGAS = "ragas"
DEEPEVAL = "deepeval"
NONE = "none"
class BaseEvaluationConfig(BaseModel):
"""Base configuration for evaluation frameworks."""
pass
class RagasConfig(BaseEvaluationConfig):
"""RAGAS-specific configuration."""
pass
class DeepEvalConfig(BaseEvaluationConfig):
"""DeepEval-specific configuration."""
pass

View File

@@ -1,212 +0,0 @@
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from core.evaluation.entities.judgment_entity import JudgmentConfig, JudgmentResult
class EvaluationCategory(StrEnum):
LLM = "llm"
RETRIEVAL = "knowledge_retrieval"
AGENT = "agent"
WORKFLOW = "workflow"
SNIPPET = "snippet"
KNOWLEDGE_BASE = "knowledge_base"
class EvaluationMetricName(StrEnum):
"""Canonical metric names shared across all evaluation frameworks.
Each framework maps these names to its own internal implementation.
A framework that does not support a given metric should log a warning
and skip it rather than raising an error.
── LLM / general text-quality metrics ──────────────────────────────────
FAITHFULNESS
Measures whether every claim in the model's response is grounded in
the provided retrieved context. A high score means the answer
contains no hallucinated content — each statement can be traced back
to a passage in the context.
Required fields: user_input, response, retrieved_contexts.
ANSWER_RELEVANCY
Measures how well the model's response addresses the user's question.
A high score means the answer stays on-topic; a low score indicates
irrelevant content or a failure to answer the actual question.
Required fields: user_input, response.
ANSWER_CORRECTNESS
Measures the factual accuracy and completeness of the model's answer
relative to a ground-truth reference. It combines semantic similarity
with key-fact coverage, so both meaning and content matter.
Required fields: user_input, response, reference (expected_output).
SEMANTIC_SIMILARITY
Measures the cosine similarity between the model's response and the
reference answer in an embedding space. It evaluates whether the two
texts convey the same meaning, independent of factual correctness.
Required fields: response, reference (expected_output).
── Retrieval-quality metrics ────────────────────────────────────────────
CONTEXT_PRECISION
Measures the proportion of retrieved context chunks that are actually
relevant to the question (precision). A high score means the retrieval
pipeline returns little noise.
Required fields: user_input, reference, retrieved_contexts.
CONTEXT_RECALL
Measures the proportion of ground-truth information that is covered by
the retrieved context chunks (recall). A high score means the retrieval
pipeline does not miss important supporting evidence.
Required fields: user_input, reference, retrieved_contexts.
CONTEXT_RELEVANCE
Measures how relevant each individual retrieved chunk is to the query.
Similar to CONTEXT_PRECISION but evaluated at the chunk level rather
than against a reference answer.
Required fields: user_input, retrieved_contexts.
── Agent-quality metrics ────────────────────────────────────────────────
TOOL_CORRECTNESS
Measures the correctness of the tool calls made by the agent during
task execution — both the choice of tool and the arguments passed.
A high score means the agent's tool-use strategy matches the expected
behavior.
Required fields: actual tool calls vs. expected tool calls.
TASK_COMPLETION
Measures whether the agent ultimately achieves the user's stated goal.
It evaluates the reasoning chain, intermediate steps, and final output
holistically; a high score means the task was fully accomplished.
Required fields: user_input, actual_output.
"""
# LLM / general text-quality metrics
FAITHFULNESS = "faithfulness"
ANSWER_RELEVANCY = "answer_relevancy"
ANSWER_CORRECTNESS = "answer_correctness"
SEMANTIC_SIMILARITY = "semantic_similarity"
# Retrieval-quality metrics
CONTEXT_PRECISION = "context_precision"
CONTEXT_RECALL = "context_recall"
CONTEXT_RELEVANCE = "context_relevance"
# Agent-quality metrics
TOOL_CORRECTNESS = "tool_correctness"
TASK_COMPLETION = "task_completion"
# Per-category canonical metric lists used by get_supported_metrics().
LLM_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.FAITHFULNESS, # Every claim is grounded in context; no hallucinations
EvaluationMetricName.ANSWER_RELEVANCY, # Response stays on-topic and addresses the question
EvaluationMetricName.ANSWER_CORRECTNESS, # Factual accuracy and completeness vs. reference
EvaluationMetricName.SEMANTIC_SIMILARITY, # Semantic closeness to the reference answer
]
RETRIEVAL_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.CONTEXT_PRECISION, # Fraction of retrieved chunks that are relevant (precision)
EvaluationMetricName.CONTEXT_RECALL, # Fraction of ground-truth info covered by retrieval (recall)
EvaluationMetricName.CONTEXT_RELEVANCE, # Per-chunk relevance to the query
]
AGENT_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.TOOL_CORRECTNESS, # Correct tool selection and arguments
EvaluationMetricName.TASK_COMPLETION, # Whether the agent fully achieves the user's goal
]
WORKFLOW_METRIC_NAMES: list[EvaluationMetricName] = [
EvaluationMetricName.FAITHFULNESS,
EvaluationMetricName.ANSWER_RELEVANCY,
EvaluationMetricName.ANSWER_CORRECTNESS,
]
METRIC_NODE_TYPE_MAPPING: dict[str, str] = {
**{m.value: "llm" for m in LLM_METRIC_NAMES},
**{m.value: "knowledge-retrieval" for m in RETRIEVAL_METRIC_NAMES},
**{m.value: "agent" for m in AGENT_METRIC_NAMES},
}
class EvaluationMetric(BaseModel):
name: str
value: Any
details: dict[str, Any] = Field(default_factory=dict)
class EvaluationItemInput(BaseModel):
index: int
inputs: dict[str, Any]
output: str
expected_output: str | None = None
context: list[str] | None = None
class EvaluationDatasetInput(BaseModel):
index: int
inputs: dict[str, Any]
expected_output: str | None = None
class EvaluationItemResult(BaseModel):
index: int
actual_output: str | None = None
metrics: list[EvaluationMetric] = Field(default_factory=list)
metadata: dict[str, Any] = Field(default_factory=dict)
judgment: JudgmentResult = Field(default_factory=JudgmentResult)
error: str | None = None
class NodeInfo(BaseModel):
node_id: str
type: str
title: str
class DefaultMetric(BaseModel):
metric: str
node_info_list: list[NodeInfo]
class CustomizedMetricOutputField(BaseModel):
variable: str
value_type: str
class CustomizedMetrics(BaseModel):
evaluation_workflow_id: str
input_fields: dict[str, Any]
output_fields: list[CustomizedMetricOutputField]
class EvaluationConfigData(BaseModel):
"""Structured data for saving evaluation configuration."""
evaluation_model: str = ""
evaluation_model_provider: str = ""
default_metrics: list[DefaultMetric] = Field(default_factory=list)
customized_metrics: CustomizedMetrics | None = None
judgment_config: JudgmentConfig | None = None
class EvaluationRunRequest(EvaluationConfigData):
"""Request body for starting an evaluation run."""
file_id: str
class EvaluationRunData(BaseModel):
"""Serializable data for Celery task."""
evaluation_run_id: str
tenant_id: str
target_type: str
target_id: str
evaluation_model_provider: str
evaluation_model: str
default_metrics: list[DefaultMetric] = Field(default_factory=list)
customized_metrics: CustomizedMetrics | None = None
judgment_config: JudgmentConfig | None = None
input_list: list[EvaluationDatasetInput]
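A hedged sketch of how these entities fit together when configuring a customized evaluator; every id, model name, and field value below is a placeholder, not a real value:

customized = CustomizedMetrics(
    evaluation_workflow_id="evaluator-app-id",   # placeholder: app id of the published evaluator workflow
    input_fields={
        "model_answer": "{{#llm_node.text#}}",   # resolved from the target run's node outputs
        "reference": "expected answer text",     # plain constant
    },
    output_fields=[CustomizedMetricOutputField(variable="accuracy", value_type="number")],
)

request = EvaluationRunRequest(
    evaluation_model_provider="openai",          # placeholder provider
    evaluation_model="gpt-4o-mini",              # placeholder model
    customized_metrics=customized,
    file_id="uploaded-dataset-file-id",          # placeholder
)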

View File

@@ -1,129 +0,0 @@
"""Judgment condition entities for evaluation metric assessment.
Key concepts:
- **condition_type**: Determines operator semantics and type coercion.
- "string": string operators (contains, is, start with, …).
- "number": numeric operators (>, <, =, ≠, ≥, ≤).
- "datetime": temporal operators (before, after).
Typical usage:
judgment_config = JudgmentConfig(
logical_operator="and",
conditions=[
JudgmentCondition(
metric_name="faithfulness",
comparison_operator=">",
condition_value="0.8",
condition_type="number",
)
],
)
"""
from enum import StrEnum
from typing import Any, Literal
from pydantic import BaseModel, Field
class JudgmentConditionType(StrEnum):
"""Category of the condition, controls operator semantics and type coercion."""
STRING = "string"
NUMBER = "number"
DATETIME = "datetime"
# Supported comparison operators for judgment conditions.
JudgmentComparisonOperator = Literal[
# string
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
"in",
"not in",
# number
"=",
"",
">",
"<",
"",
"",
# datetime
"before",
"after",
# universal
"null",
"not null",
]
class JudgmentCondition(BaseModel):
"""A single judgment condition that checks one metric value.
Attributes:
metric_name: The name of the evaluation metric to check (left side).
Must match an EvaluationMetric.name in the results.
comparison_operator: The comparison operator to apply.
condition_value: The comparison target (right side). For unary operators
such as ``empty`` or ``null`` this can be ``None``.
condition_type: Controls type coercion and which operators are valid.
"string" (default), "number", or "datetime".
"""
metric_name: str
comparison_operator: JudgmentComparisonOperator
condition_value: Any | None = None
condition_type: JudgmentConditionType = JudgmentConditionType.STRING
class JudgmentConfig(BaseModel):
"""A group of judgment conditions combined with a logical operator.
Attributes:
logical_operator: How to combine condition results — "and" requires
all conditions to pass, "or" requires at least one.
conditions: The list of individual conditions to evaluate.
"""
logical_operator: Literal["and", "or"] = "and"
conditions: list[JudgmentCondition] = Field(default_factory=list)
class JudgmentConditionResult(BaseModel):
"""Result of evaluating a single judgment condition.
Attributes:
metric_name: Which metric was checked.
comparison_operator: The operator that was applied.
expected_value: The resolved comparison value (after variable resolution).
actual_value: The actual metric value that was evaluated.
passed: Whether this individual condition passed.
error: Error message if the condition evaluation failed.
"""
metric_name: str
comparison_operator: str
expected_value: Any = None
actual_value: Any = None
passed: bool = False
error: str | None = None
class JudgmentResult(BaseModel):
"""Overall result of evaluating all judgment conditions for one item.
Attributes:
passed: Whether the overall judgment passed (based on logical_operator).
logical_operator: The logical operator used to combine conditions.
condition_results: Detailed result for each individual condition.
"""
passed: bool = False
logical_operator: Literal["and", "or"] = "and"
condition_results: list[JudgmentConditionResult] = Field(default_factory=list)

View File

@@ -1,61 +0,0 @@
import collections
import logging
from typing import Any
from configs import dify_config
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import EvaluationFrameworkEnum
from core.evaluation.entities.evaluation_entity import EvaluationCategory
logger = logging.getLogger(__name__)
class EvaluationFrameworkConfigMap(collections.UserDict[str, dict[str, Any]]):
"""Registry mapping framework enum -> {config_class, evaluator_class}."""
def __getitem__(self, framework: str) -> dict[str, Any]:
match framework:
case EvaluationFrameworkEnum.RAGAS:
from core.evaluation.entities.config_entity import RagasConfig
from core.evaluation.frameworks.ragas.ragas_evaluator import RagasEvaluator
return {
"config_class": RagasConfig,
"evaluator_class": RagasEvaluator,
}
case EvaluationFrameworkEnum.DEEPEVAL:
raise NotImplementedError("DeepEval adapter is not yet implemented.")
case _:
raise ValueError(f"Unknown evaluation framework: {framework}")
evaluation_framework_config_map = EvaluationFrameworkConfigMap()
class EvaluationManager:
"""Factory for evaluation instances based on global configuration."""
@staticmethod
def get_evaluation_instance() -> BaseEvaluationInstance | None:
"""Create and return an evaluation instance based on EVALUATION_FRAMEWORK env var."""
framework = dify_config.EVALUATION_FRAMEWORK
if not framework or framework == EvaluationFrameworkEnum.NONE:
return None
try:
config_map = evaluation_framework_config_map[framework]
evaluator_class = config_map["evaluator_class"]
config_class = config_map["config_class"]
config = config_class()
return evaluator_class(config)
except Exception:
logger.exception("Failed to create evaluation instance for framework: %s", framework)
return None
@staticmethod
def get_supported_metrics(category: EvaluationCategory) -> list[str]:
"""Return supported metrics for the current framework and given category."""
instance = EvaluationManager.get_evaluation_instance()
if instance is None:
return []
return instance.get_supported_metrics(category)
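A minimal usage sketch, assuming the EVALUATION_FRAMEWORK setting is "ragas"; with it unset or "none", get_evaluation_instance returns None:

instance = EvaluationManager.get_evaluation_instance()
if instance is None:
    supported = []   # no framework configured, or instance creation failed
else:
    supported = instance.get_supported_metrics(EvaluationCategory.LLM)
# With ragas available, `supported` would include names such as "faithfulness" and "answer_relevancy".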

View File

@@ -1,299 +0,0 @@
import logging
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import DeepEvalConfig
from core.evaluation.entities.evaluation_entity import (
AGENT_METRIC_NAMES,
LLM_METRIC_NAMES,
RETRIEVAL_METRIC_NAMES,
WORKFLOW_METRIC_NAMES,
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
EvaluationMetricName,
)
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
logger = logging.getLogger(__name__)
# Maps canonical EvaluationMetricName to the corresponding deepeval metric class name.
# deepeval metric field requirements (LLMTestCase fields):
# - faithfulness: input, actual_output, retrieval_context
# - answer_relevancy: input, actual_output
# - context_precision: input, actual_output, expected_output, retrieval_context
# - context_recall: input, actual_output, expected_output, retrieval_context
# - context_relevance: input, actual_output, retrieval_context
# - tool_correctness: input, actual_output, expected_tools
# - task_completion: input, actual_output
# Metrics not listed here are unsupported by deepeval and will be skipped.
_DEEPEVAL_METRIC_MAP: dict[EvaluationMetricName, str] = {
EvaluationMetricName.FAITHFULNESS: "FaithfulnessMetric",
EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancyMetric",
EvaluationMetricName.CONTEXT_PRECISION: "ContextualPrecisionMetric",
EvaluationMetricName.CONTEXT_RECALL: "ContextualRecallMetric",
EvaluationMetricName.CONTEXT_RELEVANCE: "ContextualRelevancyMetric",
EvaluationMetricName.TOOL_CORRECTNESS: "ToolCorrectnessMetric",
EvaluationMetricName.TASK_COMPLETION: "TaskCompletionMetric",
}
class DeepEvalEvaluator(BaseEvaluationInstance):
"""DeepEval framework adapter for evaluation."""
def __init__(self, config: DeepEvalConfig):
self.config = config
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
match category:
case EvaluationCategory.LLM:
candidates = LLM_METRIC_NAMES
case EvaluationCategory.RETRIEVAL:
candidates = RETRIEVAL_METRIC_NAMES
case EvaluationCategory.AGENT:
candidates = AGENT_METRIC_NAMES
case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
candidates = WORKFLOW_METRIC_NAMES
case _:
return []
return [m for m in candidates if m in _DEEPEVAL_METRIC_MAP]
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
def evaluate_workflow(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
def _evaluate(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Core evaluation logic using DeepEval."""
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
requested_metrics = metric_names or self.get_supported_metrics(category)
try:
return self._evaluate_with_deepeval(items, requested_metrics, category)
except ImportError:
logger.warning("DeepEval not installed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
def _evaluate_with_deepeval(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Evaluate using DeepEval library.
Builds LLMTestCase differently per category:
- LLM/Workflow: input=prompt, actual_output=output, retrieval_context=context
- Retrieval: input=query, actual_output=output, expected_output, retrieval_context=context
- Agent: input=query, actual_output=output
"""
metric_pairs = _build_deepeval_metrics(requested_metrics)
if not metric_pairs:
logger.warning("No valid DeepEval metrics found for: %s", requested_metrics)
return [EvaluationItemResult(index=item.index) for item in items]
results: list[EvaluationItemResult] = []
for item in items:
test_case = self._build_test_case(item, category)
metrics: list[EvaluationMetric] = []
for canonical_name, metric in metric_pairs:
try:
metric.measure(test_case)
if metric.score is not None:
metrics.append(EvaluationMetric(name=canonical_name, value=float(metric.score)))
except Exception:
logger.exception(
"Failed to compute metric %s for item %d",
canonical_name,
item.index,
)
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
@staticmethod
def _build_test_case(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
"""Build a deepeval LLMTestCase with the correct fields per category."""
from deepeval.test_case import LLMTestCase
user_input = _format_input(item.inputs, category)
match category:
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
# faithfulness needs: input, actual_output, retrieval_context
# answer_relevancy needs: input, actual_output
return LLMTestCase(
input=user_input,
actual_output=item.output,
expected_output=item.expected_output or None,
retrieval_context=item.context or None,
)
case EvaluationCategory.RETRIEVAL:
# contextual_precision/recall needs: input, actual_output, expected_output, retrieval_context
return LLMTestCase(
input=user_input,
actual_output=item.output or "",
expected_output=item.expected_output or "",
retrieval_context=item.context or [],
)
case _:
return LLMTestCase(
input=user_input,
actual_output=item.output,
)
def _evaluate_simple(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
) -> list[EvaluationItemResult]:
"""Simple LLM-as-judge fallback when DeepEval is not available."""
results: list[EvaluationItemResult] = []
for item in items:
metrics: list[EvaluationMetric] = []
for m_name in requested_metrics:
try:
score = self._judge_with_llm(model_wrapper, m_name, item)
metrics.append(EvaluationMetric(name=m_name, value=score))
except Exception:
logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
def _judge_with_llm(
self,
model_wrapper: DifyModelWrapper,
metric_name: str,
item: EvaluationItemInput,
) -> float:
"""Use the LLM to judge a single metric for a single item."""
prompt = self._build_judge_prompt(metric_name, item)
response = model_wrapper.invoke(prompt)
return self._parse_score(response)
@staticmethod
def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
"""Build a scoring prompt for the LLM judge."""
parts = [
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
f"\nInput: {item.inputs}",
f"\nOutput: {item.output}",
]
if item.expected_output:
parts.append(f"\nExpected Output: {item.expected_output}")
if item.context:
parts.append(f"\nContext: {'; '.join(item.context)}")
parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
return "\n".join(parts)
@staticmethod
def _parse_score(response: str) -> float:
"""Parse a float score from LLM response."""
import re
cleaned = response.strip()
try:
score = float(cleaned)
return max(0.0, min(1.0, score))
except ValueError:
match = re.search(r"(\d+\.?\d*)", cleaned)
if match:
score = float(match.group(1))
return max(0.0, min(1.0, score))
return 0.0
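Since the fallback judge asks the model for a bare float, _parse_score tolerates minor formatting drift and clamps the result to [0.0, 1.0]. A few illustrative inputs, with behaviour read directly from the code above:

DeepEvalEvaluator._parse_score("0.85")           # -> 0.85
DeepEvalEvaluator._parse_score("Score: 1.7")     # -> 1.0  (first number found, then clamped)
DeepEvalEvaluator._parse_score("no score here")  # -> 0.0  (no number found)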
def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
"""Extract the user-facing input string from the inputs dict."""
match category:
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
return str(inputs.get("prompt", ""))
case EvaluationCategory.RETRIEVAL:
return str(inputs.get("query", ""))
case _:
return str(next(iter(inputs.values()), "")) if inputs else ""
def _build_deepeval_metrics(requested_metrics: list[str]) -> list[tuple[str, Any]]:
"""Build DeepEval metric instances from canonical metric names.
Returns a list of (canonical_name, metric_instance) pairs so that callers
can record the canonical name rather than the framework-internal class name.
"""
try:
from deepeval.metrics import (
AnswerRelevancyMetric,
ContextualPrecisionMetric,
ContextualRecallMetric,
ContextualRelevancyMetric,
FaithfulnessMetric,
TaskCompletionMetric,
ToolCorrectnessMetric,
)
# Maps canonical name → deepeval metric class
deepeval_class_map: dict[str, Any] = {
EvaluationMetricName.FAITHFULNESS: FaithfulnessMetric,
EvaluationMetricName.ANSWER_RELEVANCY: AnswerRelevancyMetric,
EvaluationMetricName.CONTEXT_PRECISION: ContextualPrecisionMetric,
EvaluationMetricName.CONTEXT_RECALL: ContextualRecallMetric,
EvaluationMetricName.CONTEXT_RELEVANCE: ContextualRelevancyMetric,
EvaluationMetricName.TOOL_CORRECTNESS: ToolCorrectnessMetric,
EvaluationMetricName.TASK_COMPLETION: TaskCompletionMetric,
}
pairs: list[tuple[str, Any]] = []
for name in requested_metrics:
metric_class = deepeval_class_map.get(name)
if metric_class:
pairs.append((name, metric_class(threshold=0.5)))
else:
logger.warning("Metric '%s' is not supported by DeepEval, skipping", name)
return pairs
except ImportError:
logger.warning("DeepEval metrics not available")
return []

View File

@@ -1,312 +0,0 @@
import logging
from typing import Any
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.config_entity import RagasConfig
from core.evaluation.entities.evaluation_entity import (
AGENT_METRIC_NAMES,
LLM_METRIC_NAMES,
RETRIEVAL_METRIC_NAMES,
WORKFLOW_METRIC_NAMES,
EvaluationCategory,
EvaluationItemInput,
EvaluationItemResult,
EvaluationMetric,
EvaluationMetricName,
)
from core.evaluation.frameworks.ragas.ragas_model_wrapper import DifyModelWrapper
logger = logging.getLogger(__name__)
# Maps canonical EvaluationMetricName to the corresponding ragas metric class.
# Metrics not listed here are unsupported by ragas and will be skipped.
_RAGAS_METRIC_MAP: dict[EvaluationMetricName, str] = {
EvaluationMetricName.FAITHFULNESS: "Faithfulness",
EvaluationMetricName.ANSWER_RELEVANCY: "AnswerRelevancy",
EvaluationMetricName.ANSWER_CORRECTNESS: "AnswerCorrectness",
EvaluationMetricName.SEMANTIC_SIMILARITY: "SemanticSimilarity",
EvaluationMetricName.CONTEXT_PRECISION: "ContextPrecision",
EvaluationMetricName.CONTEXT_RECALL: "ContextRecall",
EvaluationMetricName.CONTEXT_RELEVANCE: "ContextRelevance",
EvaluationMetricName.TOOL_CORRECTNESS: "ToolCallAccuracy",
}
class RagasEvaluator(BaseEvaluationInstance):
"""RAGAS framework adapter for evaluation."""
def __init__(self, config: RagasConfig):
self.config = config
def get_supported_metrics(self, category: EvaluationCategory) -> list[str]:
match category:
case EvaluationCategory.LLM:
candidates = LLM_METRIC_NAMES
case EvaluationCategory.RETRIEVAL:
candidates = RETRIEVAL_METRIC_NAMES
case EvaluationCategory.AGENT:
candidates = AGENT_METRIC_NAMES
case EvaluationCategory.WORKFLOW | EvaluationCategory.SNIPPET:
candidates = WORKFLOW_METRIC_NAMES
case _:
return []
return [m for m in candidates if m in _RAGAS_METRIC_MAP]
def evaluate_llm(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.LLM)
def evaluate_retrieval(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.RETRIEVAL)
def evaluate_agent(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.AGENT)
def evaluate_workflow(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
return self._evaluate(items, metric_names, model_provider, model_name, tenant_id, EvaluationCategory.WORKFLOW)
def _evaluate(
self,
items: list[EvaluationItemInput],
metric_names: list[str],
model_provider: str,
model_name: str,
tenant_id: str,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Core evaluation logic using RAGAS."""
model_wrapper = DifyModelWrapper(model_provider, model_name, tenant_id)
requested_metrics = metric_names or self.get_supported_metrics(category)
try:
return self._evaluate_with_ragas(items, requested_metrics, model_wrapper, category)
except ImportError:
logger.warning("RAGAS not installed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
def _evaluate_with_ragas(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
category: EvaluationCategory,
) -> list[EvaluationItemResult]:
"""Evaluate using RAGAS library.
Builds SingleTurnSample differently per category to match ragas requirements:
- LLM/Workflow: user_input=prompt, response=output, reference=expected_output
- Retrieval: user_input=query, reference=expected_output, retrieved_contexts=context
- Agent: Not supported via EvaluationDataset (requires message-based API)
"""
from ragas import evaluate as ragas_evaluate
from ragas.dataset_schema import EvaluationDataset
samples: list[Any] = []
for item in items:
sample = self._build_sample(item, category)
samples.append(sample)
dataset = EvaluationDataset(samples=samples)
ragas_metrics = self._build_ragas_metrics(requested_metrics)
if not ragas_metrics:
logger.warning("No valid RAGAS metrics found for: %s", requested_metrics)
return [EvaluationItemResult(index=item.index) for item in items]
try:
result = ragas_evaluate(
dataset=dataset,
metrics=ragas_metrics,
)
results: list[EvaluationItemResult] = []
result_df = result.to_pandas()
for i, item in enumerate(items):
metrics: list[EvaluationMetric] = []
for m_name in requested_metrics:
if m_name in result_df.columns:
score = result_df.iloc[i][m_name]
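# `score != score` is a NaN check: ragas leaves NaN for metrics it could not compute, so those are skipped.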
if score is not None and not (isinstance(score, float) and score != score):
metrics.append(EvaluationMetric(name=m_name, value=float(score)))
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
except Exception:
logger.exception("RAGAS evaluation failed, falling back to simple evaluation")
return self._evaluate_simple(items, requested_metrics, model_wrapper)
@staticmethod
def _build_sample(item: EvaluationItemInput, category: EvaluationCategory) -> Any:
"""Build a ragas SingleTurnSample with the correct fields per category.
ragas metric field requirements:
- faithfulness: user_input, response, retrieved_contexts
- answer_relevancy: user_input, response
- answer_correctness: user_input, response, reference
- semantic_similarity: user_input, response, reference
- context_precision: user_input, reference, retrieved_contexts
- context_recall: user_input, reference, retrieved_contexts
- context_relevance: user_input, retrieved_contexts
"""
from ragas.dataset_schema import SingleTurnSample
user_input = _format_input(item.inputs, category)
match category:
case EvaluationCategory.LLM:
# response = actual LLM output, reference = expected output
return SingleTurnSample(
user_input=user_input,
response=item.output,
reference=item.expected_output or "",
retrieved_contexts=item.context or [],
)
case EvaluationCategory.RETRIEVAL:
# context_precision/recall only need reference + retrieved_contexts
return SingleTurnSample(
user_input=user_input,
reference=item.expected_output or "",
retrieved_contexts=item.context or [],
)
case _:
return SingleTurnSample(
user_input=user_input,
response=item.output,
)
def _evaluate_simple(
self,
items: list[EvaluationItemInput],
requested_metrics: list[str],
model_wrapper: DifyModelWrapper,
) -> list[EvaluationItemResult]:
"""Simple LLM-as-judge fallback when RAGAS is not available."""
results: list[EvaluationItemResult] = []
for item in items:
metrics: list[EvaluationMetric] = []
for m_name in requested_metrics:
try:
score = self._judge_with_llm(model_wrapper, m_name, item)
metrics.append(EvaluationMetric(name=m_name, value=score))
except Exception:
logger.exception("Failed to compute metric %s for item %d", m_name, item.index)
results.append(EvaluationItemResult(index=item.index, metrics=metrics))
return results
def _judge_with_llm(
self,
model_wrapper: DifyModelWrapper,
metric_name: str,
item: EvaluationItemInput,
) -> float:
"""Use the LLM to judge a single metric for a single item."""
prompt = self._build_judge_prompt(metric_name, item)
response = model_wrapper.invoke(prompt)
return self._parse_score(response)
@staticmethod
def _build_judge_prompt(metric_name: str, item: EvaluationItemInput) -> str:
"""Build a scoring prompt for the LLM judge."""
parts = [
f"Evaluate the following on the metric '{metric_name}' using a scale of 0.0 to 1.0.",
f"\nInput: {item.inputs}",
f"\nOutput: {item.output}",
]
if item.expected_output:
parts.append(f"\nExpected Output: {item.expected_output}")
if item.context:
parts.append(f"\nContext: {'; '.join(item.context)}")
parts.append("\nRespond with ONLY a single floating point number between 0.0 and 1.0, nothing else.")
return "\n".join(parts)
@staticmethod
def _parse_score(response: str) -> float:
"""Parse a float score from LLM response."""
import re
cleaned = response.strip()
try:
score = float(cleaned)
return max(0.0, min(1.0, score))
except ValueError:
match = re.search(r"(\d+\.?\d*)", cleaned)
if match:
score = float(match.group(1))
return max(0.0, min(1.0, score))
return 0.0
@staticmethod
def _build_ragas_metrics(requested_metrics: list[str]) -> list[Any]:
"""Build RAGAS metric instances from canonical metric names."""
try:
from ragas.metrics.collections import (
AnswerCorrectness,
AnswerRelevancy,
ContextPrecision,
ContextRecall,
ContextRelevance,
Faithfulness,
SemanticSimilarity,
ToolCallAccuracy,
)
# Maps canonical name → ragas metric class
ragas_class_map: dict[str, Any] = {
EvaluationMetricName.FAITHFULNESS: Faithfulness,
EvaluationMetricName.ANSWER_RELEVANCY: AnswerRelevancy,
EvaluationMetricName.ANSWER_CORRECTNESS: AnswerCorrectness,
EvaluationMetricName.SEMANTIC_SIMILARITY: SemanticSimilarity,
EvaluationMetricName.CONTEXT_PRECISION: ContextPrecision,
EvaluationMetricName.CONTEXT_RECALL: ContextRecall,
EvaluationMetricName.CONTEXT_RELEVANCE: ContextRelevance,
EvaluationMetricName.TOOL_CORRECTNESS: ToolCallAccuracy,
}
metrics = []
for name in requested_metrics:
metric_class = ragas_class_map.get(name)
if metric_class:
metrics.append(metric_class())
else:
logger.warning("Metric '%s' is not supported by RAGAS, skipping", name)
return metrics
except ImportError:
logger.warning("RAGAS metrics not available")
return []
def _format_input(inputs: dict[str, Any], category: EvaluationCategory) -> str:
"""Extract the user-facing input string from the inputs dict."""
match category:
case EvaluationCategory.LLM | EvaluationCategory.WORKFLOW:
return str(inputs.get("prompt", ""))
case EvaluationCategory.RETRIEVAL:
return str(inputs.get("query", ""))
case _:
return str(next(iter(inputs.values()), "")) if inputs else ""

View File

@@ -1,48 +0,0 @@
import logging
from typing import Any
logger = logging.getLogger(__name__)
class DifyModelWrapper:
"""Wraps Dify's model invocation interface for use by RAGAS as an LLM judge.
RAGAS requires an LLM to compute certain metrics (faithfulness, answer_relevancy, etc.).
This wrapper bridges Dify's ModelInstance to a callable that RAGAS can use.
"""
def __init__(self, model_provider: str, model_name: str, tenant_id: str):
self.model_provider = model_provider
self.model_name = model_name
self.tenant_id = tenant_id
def _get_model_instance(self) -> Any:
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.model_provider,
model_type=ModelType.LLM,
model=self.model_name,
)
return model_instance
def invoke(self, prompt: str) -> str:
"""Invoke the model with a text prompt and return the text response."""
from core.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
model_instance = self._get_model_instance()
result = model_instance.invoke_llm(
prompt_messages=[
SystemPromptMessage(content="You are an evaluation judge. Answer precisely and concisely."),
UserPromptMessage(content=prompt),
],
model_parameters={"temperature": 0.0, "max_tokens": 2048},
stream=False,
)
return result.message.content
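A short usage sketch of the wrapper; the provider and model names are placeholders and the tenant is assumed to have that model configured:

judge = DifyModelWrapper(model_provider="openai", model_name="gpt-4o-mini", tenant_id=tenant_id)
verdict = judge.invoke("Rate the faithfulness of the answer on a scale of 0.0 to 1.0.")
# `verdict` is the raw text completion (e.g. "0.9"); callers such as the
# evaluators' _parse_score turn it into a clamped float.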

View File

@@ -1,294 +0,0 @@
"""Judgment condition processor for evaluation metrics.
Evaluates pass/fail judgment conditions against evaluation metric values.
Each condition uses:
- ``metric_name`` as the left-hand side lookup key from ``metric_values``
- ``comparison_operator`` as the operator
- ``condition_value`` as the right-hand side comparison value
The processor is intentionally decoupled from evaluation frameworks and
runners. It operates on plain ``dict`` mappings and can be invoked anywhere
that already has per-item metric results.
"""
import logging
from collections.abc import Sequence
from datetime import datetime
from typing import Any, cast
from core.evaluation.entities.judgment_entity import (
JudgmentCondition,
JudgmentConditionResult,
JudgmentConditionType,
JudgmentConfig,
JudgmentResult,
)
from graphon.utils.condition.entities import SupportedComparisonOperator
from graphon.utils.condition.processor import _evaluate_condition # pyright: ignore[reportPrivateUsage]
logger = logging.getLogger(__name__)
# Operators that do not need a comparison value (unary operators).
_UNARY_OPERATORS = frozenset({"null", "not null", "empty", "not empty"})
class JudgmentProcessor:
@staticmethod
def evaluate(
metric_values: dict[str, Any],
config: JudgmentConfig,
) -> JudgmentResult:
"""Evaluate all judgment conditions against the given metric values.
Args:
metric_values: Mapping of metric name → metric value
(e.g. ``{"faithfulness": 0.85, "status": "success"}``).
config: The judgment configuration with logical_operator and conditions.
Returns:
JudgmentResult with overall pass/fail and per-condition details.
"""
if not config.conditions:
return JudgmentResult(
passed=True,
logical_operator=config.logical_operator,
condition_results=[],
)
condition_results: list[JudgmentConditionResult] = []
for condition in config.conditions:
result = JudgmentProcessor._evaluate_single_condition(metric_values, condition)
condition_results.append(result)
if config.logical_operator == "and" and not result.passed:
return JudgmentResult(
passed=False,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
if config.logical_operator == "or" and result.passed:
return JudgmentResult(
passed=True,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
# All conditions evaluated
if config.logical_operator == "and":
final_passed = all(r.passed for r in condition_results)
else:
final_passed = any(r.passed for r in condition_results)
return JudgmentResult(
passed=final_passed,
logical_operator=config.logical_operator,
condition_results=condition_results,
)
@staticmethod
def _evaluate_single_condition(
metric_values: dict[str, Any],
condition: JudgmentCondition,
) -> JudgmentConditionResult:
"""Evaluate a single judgment condition.
Steps:
1. Look up the metric value (left side) by ``metric_name``.
2. Read ``condition_value`` as the comparison value (right side).
3. Dispatch to the correct type handler (string / number / datetime).
"""
metric_name = condition.metric_name
actual_value = metric_values.get(metric_name)
# Handle metric not found — skip for unary operators that work on None
if actual_value is None and condition.comparison_operator not in _UNARY_OPERATORS:
return JudgmentConditionResult(
metric_name=metric_name,
comparison_operator=condition.comparison_operator,
expected_value=condition.condition_value,
actual_value=None,
passed=False,
error=f"Metric '{metric_name}' not found in evaluation results",
)
resolved_value = condition.condition_value
# Dispatch to the appropriate type handler
try:
match condition.condition_type:
case JudgmentConditionType.DATETIME:
passed = _evaluate_datetime_condition(actual_value, condition.comparison_operator, resolved_value)
case JudgmentConditionType.NUMBER:
passed = _evaluate_number_condition(actual_value, condition.comparison_operator, resolved_value)
case _: # STRING (default) — delegate to workflow engine
if condition.comparison_operator in {"before", "after"}:
raise ValueError(
f"Operator '{condition.comparison_operator}' is not supported for string conditions"
)
passed = _evaluate_condition(
operator=cast(SupportedComparisonOperator, condition.comparison_operator),
value=actual_value,
expected=resolved_value,
)
return JudgmentConditionResult(
metric_name=metric_name,
comparison_operator=condition.comparison_operator,
expected_value=resolved_value,
actual_value=actual_value,
passed=passed,
)
except Exception as e:
logger.warning(
"Judgment condition evaluation failed for metric '%s': %s",
metric_name,
str(e),
)
return JudgmentConditionResult(
metric_name=metric_name,
comparison_operator=condition.comparison_operator,
expected_value=resolved_value,
actual_value=actual_value,
passed=False,
error=str(e),
)
_DATETIME_FORMATS = [
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%dT%H:%M:%S.%f",
"%Y-%m-%dT%H:%M:%SZ",
"%Y-%m-%dT%H:%M:%S.%fZ",
"%Y-%m-%dT%H:%M:%S%z",
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d",
]
def _parse_datetime(value: object) -> datetime:
"""Parse a value into a datetime object.
Accepts datetime instances, numeric timestamps (int/float), and common
ISO 8601 string formats.
Raises:
ValueError: If the value cannot be parsed as a datetime.
"""
if isinstance(value, datetime):
return value
if isinstance(value, (int, float)):
return datetime.fromtimestamp(value)
if not isinstance(value, str):
raise ValueError(f"Cannot parse '{value}' (type={type(value).__name__}) as datetime")
for fmt in _DATETIME_FORMATS:
try:
return datetime.strptime(value, fmt)
except ValueError:
continue
raise ValueError(
f"Cannot parse datetime string '{value}'. "
f"Supported formats: ISO 8601, 'YYYY-MM-DD HH:MM:SS', 'YYYY-MM-DD', or numeric timestamp."
)
def _evaluate_datetime_condition(
actual: object,
operator: str,
expected: object,
) -> bool:
"""Evaluate a datetime comparison condition.
Also supports the universal unary operators (null, not null, empty, not empty)
and the numeric-style operators (=, ≠, >, <, ≥, ≤) for datetime values.
Args:
actual: The actual metric value (left side).
operator: The comparison operator.
expected: The expected/threshold value (right side).
Returns:
True if the condition passes.
Raises:
ValueError: If values cannot be parsed or operator is unsupported.
"""
# Handle unary operators first
if operator == "null":
return actual is None
if operator == "not null":
return actual is not None
if operator == "empty":
return not actual
if operator == "not empty":
return bool(actual)
if actual is None:
return False
actual_dt = _parse_datetime(actual)
expected_dt = _parse_datetime(expected) if expected is not None else None
if expected_dt is None:
raise ValueError(f"Expected datetime value is required for operator '{operator}'")
match operator:
case "before" | "<":
return actual_dt < expected_dt
case "after" | ">":
return actual_dt > expected_dt
case "=" | "is":
return actual_dt == expected_dt
case "" | "is not":
return actual_dt != expected_dt
case "":
return actual_dt >= expected_dt
case "":
return actual_dt <= expected_dt
case _:
raise ValueError(f"Unsupported datetime operator: '{operator}'")
def _evaluate_number_condition(
actual: object,
operator: str,
expected: object,
) -> bool:
"""Evaluate a numeric comparison condition.
Ensures proper numeric type coercion before delegating to the workflow
condition engine. This avoids string-vs-number comparison pitfalls
(e.g. comparing float metric 0.85 against string threshold "0.8").
For unary operators (null, not null, empty, not empty), delegates directly.
"""
# Unary operators — delegate to workflow engine as-is
if operator in _UNARY_OPERATORS:
return _evaluate_condition(
operator=cast(SupportedComparisonOperator, operator),
value=actual,
expected=cast(str | Sequence[str] | bool | Sequence[bool] | None, expected),
)
if actual is None:
return False
# Coerce actual to numeric
if not isinstance(actual, (int, float)):
try:
actual = float(cast(str | int | float, actual))
except (TypeError, ValueError) as e:
raise ValueError(f"Cannot convert actual value '{actual}' to number") from e
# Coerce expected to numeric string for the workflow engine
# (the workflow engine's _normalize_numeric_values handles str → float)
if expected is not None and not isinstance(expected, str):
expected = str(expected)
return _evaluate_condition(
operator=cast(SupportedComparisonOperator, operator),
value=actual,
expected=expected,
)
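An end-to-end sketch of the processor: metric values as an evaluator would produce them, checked against a numeric threshold (the metric name and threshold are illustrative):

config = JudgmentConfig(
    logical_operator="and",
    conditions=[
        JudgmentCondition(
            metric_name="faithfulness",
            comparison_operator=">",
            condition_value="0.8",
            condition_type=JudgmentConditionType.NUMBER,
        )
    ],
)
result = JudgmentProcessor.evaluate({"faithfulness": 0.92}, config)
# result.passed is True; result.condition_results[0].actual_value == 0.92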

View File

@@ -1,52 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from models import Account, App, CustomizedSnippet, TenantAccountJoin
def get_service_account_for_app(session: Session, app_id: str) -> Account:
"""Get the creator account for an app with tenant context set up.
This follows the same pattern as BaseTraceInstance.get_service_account_with_tenant().
"""
app = session.scalar(select(App).where(App.id == app_id))
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator")
account = session.scalar(select(Account).where(Account.id == app.created_by))
if not account:
raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
if not current_tenant:
raise ValueError(f"Current tenant not found for account {account.id}")
account.set_tenant_id(current_tenant.tenant_id)
return account
def get_service_account_for_snippet(session: Session, snippet_id: str) -> Account:
"""Get the creator account for a snippet with tenant context set up.
Mirrors :func:`get_service_account_for_app` but queries CustomizedSnippet.
"""
snippet = session.scalar(select(CustomizedSnippet).where(CustomizedSnippet.id == snippet_id))
if not snippet:
raise ValueError(f"Snippet with id {snippet_id} not found")
if not snippet.created_by:
raise ValueError(f"Snippet with id {snippet_id} has no creator")
account = session.scalar(select(Account).where(Account.id == snippet.created_by))
if not account:
raise ValueError(f"Creator account not found for snippet {snippet_id}")
current_tenant = session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
if not current_tenant:
raise ValueError(f"Current tenant not found for account {account.id}")
account.set_tenant_id(current_tenant.tenant_id)
return account

View File

@@ -1,154 +0,0 @@
import logging
from collections.abc import Mapping
from typing import Any
from sqlalchemy.orm import Session
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
CustomizedMetrics,
DefaultMetric,
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
from models.model import App
logger = logging.getLogger(__name__)
class AgentEvaluationRunner(BaseEvaluationRunner):
"""Runner for agent evaluation: executes agent-type App, collects tool calls and final output."""
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
super().__init__(evaluation_instance, session)
def execute_target(
self,
tenant_id: str,
target_id: str,
target_type: str,
item: EvaluationItemInput,
) -> EvaluationItemResult:
"""Execute agent app and collect response with tool call information."""
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.evaluation.runners import get_service_account_for_app
app = self.session.query(App).filter_by(id=target_id).first()
if not app:
raise ValueError(f"App {target_id} not found")
service_account = get_service_account_for_app(self.session, target_id)
query = self._extract_query(item.inputs)
args: dict[str, Any] = {
"inputs": item.inputs,
"query": query,
}
generator = AgentChatAppGenerator()
# Agent chat requires streaming - collect full response
response_generator = generator.generate(
app_model=app,
user=service_account,
args=args,
invoke_from=InvokeFrom.SERVICE_API,
streaming=True,
)
# Consume the stream to get the full response
actual_output, tool_calls = self._consume_agent_stream(response_generator)
return EvaluationItemResult(
index=item.index,
actual_output=actual_output,
metadata={"tool_calls": tool_calls},
)
def evaluate_metrics(
self,
node_run_result_mapping_list: list[dict[str, NodeRunResult]] | None,
node_run_result_list: list[NodeRunResult] | None,
default_metric: DefaultMetric | None,
customized_metrics: CustomizedMetrics | None,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute agent evaluation metrics."""
if not node_run_result_list:
return []
if not default_metric:
raise ValueError("Default metric is required for agent evaluation")
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_agent(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
"""Create EvaluationItemInput list from NodeRunResult for agent evaluation."""
merged = []
for i, item in enumerate(items):
output = _extract_agent_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs=dict(item.inputs),
output=output,
)
)
return merged
@staticmethod
def _extract_query(inputs: dict[str, Any]) -> str:
for key in ("query", "question", "input", "text"):
if key in inputs:
return str(inputs[key])
values = list(inputs.values())
return str(values[0]) if values else ""
@staticmethod
def _consume_agent_stream(response_generator: Any) -> tuple[str, list[dict]]:
"""Consume agent streaming response and extract final answer + tool calls."""
answer_parts: list[str] = []
tool_calls: list[dict] = []
try:
for chunk in response_generator:
if isinstance(chunk, Mapping):
event = chunk.get("event")
if event == "agent_thought":
thought = chunk.get("thought", "")
if thought:
answer_parts.append(thought)
tool = chunk.get("tool")
if tool:
tool_calls.append(
{
"tool": tool,
"tool_input": chunk.get("tool_input", ""),
}
)
elif event == "message":
answer = chunk.get("answer", "")
if answer:
answer_parts.append(answer)
elif isinstance(chunk, str):
answer_parts.append(chunk)
except Exception:
logger.exception("Error consuming agent stream")
return "".join(answer_parts), tool_calls
def _extract_agent_output(outputs: Mapping[str, Any]) -> str:
"""Extract the primary output text from agent NodeRunResult.outputs."""
if "answer" in outputs:
return str(outputs["answer"])
if "text" in outputs:
return str(outputs["text"])
values = list(outputs.values())
return str(values[0]) if values else ""

View File

@@ -1,179 +0,0 @@
"""Base evaluation runner.
Orchestrates the evaluation lifecycle in four phases:
1. execute_target — run the target and collect actual outputs (abstract)
2. evaluate_metrics — compute metrics via framework or customized workflow
3. apply_judgment — evaluate pass/fail judgment conditions on metrics
4. persist — save results to the database
The persisted ``EvaluationRunItem.judgment`` payload must reflect the final
judgment result for each evaluated item, so judgment evaluation happens before
the persistence phase whenever a ``JudgmentConfig`` is supplied.
"""
import json
import logging
from abc import ABC, abstractmethod
from sqlalchemy.orm import Session
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
CustomizedMetrics,
DefaultMetric,
EvaluationDatasetInput,
EvaluationItemResult,
)
from core.evaluation.entities.judgment_entity import JudgmentConfig
from core.evaluation.judgment.processor import JudgmentProcessor
from graphon.node_events import NodeRunResult
from libs.datetime_utils import naive_utc_now
from models.evaluation import EvaluationRun, EvaluationRunItem, EvaluationRunStatus
logger = logging.getLogger(__name__)
class BaseEvaluationRunner(ABC):
"""Abstract base class for evaluation runners."""
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
self.evaluation_instance = evaluation_instance
self.session = session
@abstractmethod
def evaluate_metrics(
self,
node_run_result_mapping_list: list[dict[str, NodeRunResult]] | None,
node_run_result_list: list[NodeRunResult] | None,
default_metric: DefaultMetric | None,
customized_metrics: CustomizedMetrics | None,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute evaluation metrics on the collected results."""
...
def run(
self,
evaluation_run_id: str,
tenant_id: str,
target_id: str,
target_type: str,
node_run_result_list: list[NodeRunResult] | None = None,
default_metric: DefaultMetric | None = None,
customized_metrics: CustomizedMetrics | None = None,
model_provider: str = "",
model_name: str = "",
node_run_result_mapping_list: list[dict[str, NodeRunResult]] | None = None,
judgment_config: JudgmentConfig | None = None,
input_list: list[EvaluationDatasetInput] | None = None,
) -> list[EvaluationItemResult]:
"""Orchestrate target execution + metric evaluation + judgment for all items."""
evaluation_run = self.session.query(EvaluationRun).filter_by(id=evaluation_run_id).first()
if not evaluation_run:
raise ValueError(f"EvaluationRun {evaluation_run_id} not found")
if not default_metric and not customized_metrics:
raise ValueError("Either default_metric or customized_metrics must be provided")
# Update status to running
evaluation_run.status = EvaluationRunStatus.RUNNING
evaluation_run.started_at = naive_utc_now()
self.session.commit()
results_by_index: dict[int, EvaluationItemResult] = {}
# Phase 1: run evaluation
if default_metric and node_run_result_list:
try:
evaluated_results = self.evaluate_metrics(
node_run_result_mapping_list=node_run_result_mapping_list,
node_run_result_list=node_run_result_list,
default_metric=default_metric,
customized_metrics=customized_metrics,
model_provider=model_provider,
model_name=model_name,
tenant_id=tenant_id,
)
for r in evaluated_results:
results_by_index[r.index] = r
except Exception:
logger.exception("Failed to compute metrics for evaluation run %s", evaluation_run_id)
if customized_metrics and node_run_result_mapping_list:
try:
customized_results = self.evaluation_instance.evaluate_with_customized_workflow(
node_run_result_mapping_list=node_run_result_mapping_list,
customized_metrics=customized_metrics,
tenant_id=tenant_id,
)
for r in customized_results:
existing = results_by_index.get(r.index)
if existing:
# Merge: combine metrics from both sources into one result
results_by_index[r.index] = existing.model_copy(
update={"metrics": existing.metrics + r.metrics}
)
else:
results_by_index[r.index] = r
except Exception:
logger.exception("Failed to compute customized metrics for evaluation run %s", evaluation_run_id)
results = list(results_by_index.values())
if judgment_config is not None:
results = self._apply_judgment(
results=results,
judgment_config=judgment_config,
node_run_result_mapping_list=node_run_result_mapping_list,
)
# Phase 4: Persist individual items
dataset_items = input_list or []
for result in results:
item_input = next((item for item in dataset_items if item.index == result.index), None)
run_item = EvaluationRunItem(
evaluation_run_id=evaluation_run_id,
item_index=result.index,
inputs=json.dumps(item_input.inputs) if item_input else None,
expected_output=item_input.expected_output if item_input else None,
context=json.dumps(item_input.context) if item_input and getattr(item_input, "context", None) else None,
actual_output=result.actual_output,
metrics=json.dumps([m.model_dump() for m in result.metrics]) if result.metrics else None,
judgment=json.dumps(result.judgment.model_dump()) if result.judgment else None,
metadata_json=json.dumps(result.metadata) if result.metadata else None,
error=result.error,
overall_score=getattr(result, "overall_score", None),
)
self.session.add(run_item)
self.session.commit()
return results
@staticmethod
def _apply_judgment(
results: list[EvaluationItemResult],
judgment_config: JudgmentConfig,
node_run_result_mapping_list: list[dict[str, NodeRunResult]] | None = None,
) -> list[EvaluationItemResult]:
"""Apply judgment conditions to each result's metrics.
Judgment is computed only from the per-item metric values and the
supplied ``JudgmentConfig``. ``metric_name`` selects the left-hand side
metric, and ``condition_value`` is used as the comparison target.
"""
judged_results: list[EvaluationItemResult] = []
for result in results:
if result.error is not None or not result.metrics:
judged_results.append(result)
continue
# Left side: only metrics
metric_values: dict[str, object] = {m.name: m.value for m in result.metrics}
judgment_result = JudgmentProcessor.evaluate(metric_values, judgment_config)
judged_results.append(result.model_copy(update={"judgment": judgment_result}))
return judged_results

View File

@@ -1,119 +0,0 @@
import logging
from collections.abc import Mapping
from typing import Any, Union
from sqlalchemy.orm import Session
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
CustomizedMetrics,
DefaultMetric,
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class LLMEvaluationRunner(BaseEvaluationRunner):
"""Runner for LLM evaluation: executes App to get responses, then evaluates."""
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
super().__init__(evaluation_instance, session)
def evaluate_metrics(
self,
node_run_result_mapping_list: list[dict[str, NodeRunResult]] | None,
node_run_result_list: list[NodeRunResult] | None,
default_metric: DefaultMetric | None,
customized_metrics: CustomizedMetrics | None,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Use the evaluation instance to compute LLM metrics."""
# Merge actual_output into items for evaluation
if not node_run_result_list:
return []
if not default_metric:
raise ValueError("Default metric is required for LLM evaluation")
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_llm(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _extract_query(inputs: dict[str, Any]) -> str:
"""Extract query from inputs."""
for key in ("query", "question", "input", "text"):
if key in inputs:
return str(inputs[key])
values = list(inputs.values())
return str(values[0]) if values else ""
@staticmethod
def _extract_output(response: Union[Mapping[str, Any], Any]) -> str:
"""Extract text output from app response."""
if isinstance(response, Mapping):
# Workflow response
if "data" in response and isinstance(response["data"], Mapping):
outputs = response["data"].get("outputs", {})
if isinstance(outputs, Mapping):
values = list(outputs.values())
return str(values[0]) if values else ""
return str(outputs)
# Completion response
if "answer" in response:
return str(response["answer"])
if "text" in response:
return str(response["text"])
return str(response)
@staticmethod
def _merge_results_into_items(
items: list[NodeRunResult],
) -> list[EvaluationItemInput]:
"""Create new items from NodeRunResult for ragas evaluation.
Extracts prompts from process_data and concatenates them into a single
string with role prefixes (e.g. "system: ...\nuser: ...\nassistant: ...").
The last assistant message in outputs is used as the actual output.
"""
merged = []
for i, item in enumerate(items):
prompt = _format_prompts(item.process_data.get("prompts", []))
output = _extract_llm_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs={"prompt": prompt},
output=output,
)
)
return merged
def _format_prompts(prompts: list[dict[str, Any]]) -> str:
"""Concatenate a list of prompt messages into a single string for evaluation.
Each message is formatted as "role: text" and joined with newlines.
"""
parts: list[str] = []
for msg in prompts:
role = msg.get("role", "unknown")
text = msg.get("text", "")
parts.append(f"{role}: {text}")
return "\n".join(parts)
def _extract_llm_output(outputs: Mapping[str, Any]) -> str:
"""Extract the LLM output text from NodeRunResult.outputs."""
if "text" in outputs:
return str(outputs["text"])
if "answer" in outputs:
return str(outputs["answer"])
# Fallback: first value
values = list(outputs.values())
return str(values[0]) if values else ""

View File

@@ -1,68 +0,0 @@
import logging
from typing import Any
from sqlalchemy.orm import Session
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
CustomizedMetrics,
DefaultMetric,
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class RetrievalEvaluationRunner(BaseEvaluationRunner):
"""Runner for retrieval evaluation: performs knowledge base retrieval, then evaluates."""
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
super().__init__(evaluation_instance, session)
def evaluate_metrics(
self,
node_run_result_mapping_list: list[dict[str, NodeRunResult]] | None,
node_run_result_list: list[NodeRunResult] | None,
default_metric: DefaultMetric | None,
customized_metrics: CustomizedMetrics | None,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute retrieval evaluation metrics."""
if not node_run_result_list:
return []
merged_items = []
for i, node_result in enumerate(node_run_result_list):
# Extract retrieved contexts from outputs
outputs = node_result.outputs
query = self._extract_query(dict(node_result.inputs))
# Extract retrieved content from result list
result_list = outputs.get("result", [])
contexts = [item.get("content", "") for item in result_list if item.get("content")]
output = "\n---\n".join(contexts)
merged_items.append(
EvaluationItemInput(
index=i,
inputs={"query": query},
output=output,
context=contexts,
)
)
return self.evaluation_instance.evaluate_retrieval(
merged_items, [default_metric.metric] if default_metric else [], model_provider, model_name, tenant_id
)
@staticmethod
def _extract_query(inputs: dict[str, Any]) -> str:
for key in ("query", "question", "input", "text"):
if key in inputs:
return str(inputs[key])
values = list(inputs.values())
return str(values[0]) if values else ""

View File

@@ -1,232 +0,0 @@
"""Runner for Snippet evaluation.
Executes a published Snippet workflow in non-streaming mode, collects the
actual outputs and per-node execution records, then delegates to the
evaluation instance for metric computation.
"""
import json
import logging
from collections.abc import Mapping, Sequence
from typing import Any
from sqlalchemy import asc, select
from sqlalchemy.orm import Session
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
CustomizedMetrics,
DefaultMetric,
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
from models.snippet import CustomizedSnippet
from models.workflow import WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
class SnippetEvaluationRunner(BaseEvaluationRunner):
"""Runner for snippet evaluation: executes a published Snippet workflow."""
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
super().__init__(evaluation_instance, session)
def execute_target(
self,
tenant_id: str,
target_id: str,
target_type: str,
item: EvaluationItemInput,
) -> EvaluationItemResult:
"""Execute a published Snippet workflow and collect outputs.
Steps:
1. Delegate execution to ``SnippetGenerateService.run_published``.
2. Extract ``workflow_run_id`` from the blocking response.
3. Query ``workflow_node_executions`` by ``workflow_run_id`` to get
each node's inputs, outputs, status, elapsed_time, etc.
4. Return result with actual_output and node_executions metadata.
"""
from core.app.entities.app_invoke_entities import InvokeFrom
from core.evaluation.runners import get_service_account_for_snippet
from services.snippet_generate_service import SnippetGenerateService
snippet = self.session.query(CustomizedSnippet).filter_by(id=target_id).first()
if not snippet:
raise ValueError(f"Snippet {target_id} not found")
if not snippet.is_published:
raise ValueError(f"Snippet {target_id} is not published")
service_account = get_service_account_for_snippet(self.session, target_id)
response = SnippetGenerateService.run_published(
snippet=snippet,
user=service_account,
args={"inputs": item.inputs},
invoke_from=InvokeFrom.SERVICE_API,
)
actual_output = self._extract_output(response)
# Retrieve per-node execution records from DB
workflow_run_id = self._extract_workflow_run_id(response)
node_executions = (
self._query_node_executions(
tenant_id=tenant_id,
app_id=target_id,
workflow_run_id=workflow_run_id,
)
if workflow_run_id
else []
)
return EvaluationItemResult(
index=item.index,
actual_output=actual_output,
metadata={
"workflow_run_id": workflow_run_id or "",
"node_executions": node_executions,
},
)
def evaluate_metrics(
self,
node_run_result_mapping_list: list[dict[str, NodeRunResult]] | None,
node_run_result_list: list[NodeRunResult] | None,
default_metric: DefaultMetric | None,
customized_metrics: CustomizedMetrics | None,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute evaluation metrics for snippet outputs.
Snippets are essentially workflows, so we reuse evaluate_workflow from
the evaluation instance.
"""
if not node_run_result_list:
return []
if not default_metric:
raise ValueError("Default metric is required for snippet evaluation")
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_workflow(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
"""Create EvaluationItemInput list from NodeRunResult for snippet evaluation."""
merged = []
for i, item in enumerate(items):
output = _extract_snippet_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs=dict(item.inputs),
output=output,
)
)
return merged
@staticmethod
def _extract_output(response: Mapping[str, Any]) -> str:
"""Extract text output from the blocking workflow response.
The blocking response ``data.outputs`` is a dict of output variables.
We take the first value as the primary output text.
"""
if "data" in response and isinstance(response["data"], Mapping):
outputs = response["data"].get("outputs", {})
if isinstance(outputs, Mapping):
values = list(outputs.values())
return str(values[0]) if values else ""
return str(outputs)
return str(response)
@staticmethod
def _extract_workflow_run_id(response: Mapping[str, Any]) -> str | None:
"""Extract workflow_run_id from the blocking response.
The blocking response has ``workflow_run_id`` at the top level and
also ``data.id`` (same value).
"""
wf_run_id = response.get("workflow_run_id")
if wf_run_id:
return str(wf_run_id)
# Fallback to data.id
data = response.get("data")
if isinstance(data, Mapping) and data.get("id"):
return str(data["id"])
return None
def _query_node_executions(
self,
tenant_id: str,
app_id: str,
workflow_run_id: str,
) -> list[dict[str, Any]]:
"""Query per-node execution records from the DB after workflow completes.
Node executions are persisted during workflow execution. We read them
back via the ``workflow_run_id`` to get each node's inputs, outputs,
status, elapsed_time, etc.
Returns a list of serialisable dicts for storage in ``metadata``.
"""
stmt = (
WorkflowNodeExecutionModel.preload_offload_data(select(WorkflowNodeExecutionModel))
.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_id,
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
)
.order_by(asc(WorkflowNodeExecutionModel.created_at))
)
node_models: Sequence[WorkflowNodeExecutionModel] = self.session.execute(stmt).scalars().all()
return [self._serialize_node_execution(node) for node in node_models]
@staticmethod
def _serialize_node_execution(node: WorkflowNodeExecutionModel) -> dict[str, Any]:
"""Convert a WorkflowNodeExecutionModel to a serialisable dict.
Includes the node's id, type, title, inputs/outputs (parsed from JSON),
status, error, and elapsed_time. The virtual Start node injected by
SnippetGenerateService is filtered out by the caller if needed.
"""
def _safe_parse_json(value: str | None) -> Any:
if not value:
return None
try:
return json.loads(value)
except (json.JSONDecodeError, TypeError):
return value
return {
"id": node.id,
"node_id": node.node_id,
"node_type": node.node_type,
"title": node.title,
"inputs": _safe_parse_json(node.inputs),
"outputs": _safe_parse_json(node.outputs),
"status": node.status,
"error": node.error,
"elapsed_time": node.elapsed_time,
}
def _extract_snippet_output(outputs: Mapping[str, Any]) -> str:
"""Extract the primary output text from snippet NodeRunResult.outputs."""
if "answer" in outputs:
return str(outputs["answer"])
if "text" in outputs:
return str(outputs["text"])
values = list(outputs.values())
return str(values[0]) if values else ""

View File

@@ -1,88 +0,0 @@
import logging
from collections.abc import Mapping
from typing import Any
from sqlalchemy.orm import Session
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
CustomizedMetrics,
DefaultMetric,
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from graphon.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class WorkflowEvaluationRunner(BaseEvaluationRunner):
"""Runner for workflow evaluation: executes workflow App in non-streaming mode."""
def __init__(self, evaluation_instance: BaseEvaluationInstance, session: Session):
super().__init__(evaluation_instance, session)
def evaluate_metrics(
self,
node_run_result_mapping_list: list[dict[str, NodeRunResult]] | None,
node_run_result_list: list[NodeRunResult] | None,
default_metric: DefaultMetric | None,
customized_metrics: CustomizedMetrics | None,
model_provider: str,
model_name: str,
tenant_id: str,
) -> list[EvaluationItemResult]:
"""Compute workflow evaluation metrics (end-to-end)."""
if not node_run_result_list:
return []
if not default_metric:
raise ValueError("Default metric is required for workflow evaluation")
merged_items = self._merge_results_into_items(node_run_result_list)
return self.evaluation_instance.evaluate_workflow(
merged_items, [default_metric.metric], model_provider, model_name, tenant_id
)
@staticmethod
def _merge_results_into_items(items: list[NodeRunResult]) -> list[EvaluationItemInput]:
"""Create EvaluationItemInput list from NodeRunResult for workflow evaluation."""
merged = []
for i, item in enumerate(items):
output = _extract_workflow_output(item.outputs)
merged.append(
EvaluationItemInput(
index=i,
inputs=dict(item.inputs),
output=output,
)
)
return merged
@staticmethod
def _extract_output(response: Mapping[str, Any]) -> str:
"""Extract text output from workflow response."""
if "data" in response and isinstance(response["data"], Mapping):
outputs = response["data"].get("outputs", {})
if isinstance(outputs, Mapping):
values = list(outputs.values())
return str(values[0]) if values else ""
return str(outputs)
return str(response)
@staticmethod
def _extract_node_executions(response: Mapping[str, Any]) -> list[dict]:
"""Extract node execution trace from workflow response."""
data = response.get("data", {})
if isinstance(data, Mapping):
return data.get("node_executions", [])
return []
def _extract_workflow_output(outputs: Mapping[str, Any]) -> str:
"""Extract the primary output text from workflow NodeRunResult.outputs."""
if "answer" in outputs:
return str(outputs["answer"])
if "text" in outputs:
return str(outputs["text"])
values = list(outputs.values())
return str(values[0]) if values else ""

View File

@@ -39,6 +39,7 @@ from core.ops.entities.trace_entity import (
)
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@@ -300,7 +301,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
"app_name": node_execution.title,
"status": node_execution.status,
"status_message": node_execution.error or "",
"level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
"level": "ERROR" if node_execution.status == WorkflowNodeExecutionStatus.FAILED else "DEFAULT",
}
)
@@ -361,7 +362,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
node_span.set_attributes(llm_attributes)
finally:
if node_execution.status == "failed":
if node_execution.status == WorkflowNodeExecutionStatus.FAILED:
set_span_status(node_span, node_execution.error)
else:
set_span_status(node_span)

View File

@@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
class BaseTraceInfo(BaseModel):
message_id: str | None = None
message_data: Any | None = None
inputs: Union[str, dict[str, Any], list] | None = None
outputs: Union[str, dict[str, Any], list] | None = None
inputs: Union[str, dict[str, Any], list[Any]] | None = None
outputs: Union[str, dict[str, Any], list[Any]] | None = None
start_time: datetime | None = None
end_time: datetime | None = None
metadata: dict[str, Any]
@@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel):
@field_validator("inputs", "outputs")
@classmethod
def ensure_type(cls, v):
def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None:
if v is None:
return None
if isinstance(v, str | dict | list):
@@ -27,6 +27,48 @@ class BaseTraceInfo(BaseModel):
model_config = ConfigDict(protected_namespaces=())
@property
def resolved_trace_id(self) -> str | None:
"""Get trace_id with intelligent fallback.
Priority:
1. External trace_id (from X-Trace-Id header)
2. workflow_run_id (if this trace type has it)
3. message_id (as final fallback)
"""
if self.trace_id:
return self.trace_id
# Try workflow_run_id (only exists on workflow-related traces)
workflow_run_id = getattr(self, "workflow_run_id", None)
if workflow_run_id:
return workflow_run_id
# Final fallback to message_id
return str(self.message_id) if self.message_id else None
@property
def resolved_parent_context(self) -> tuple[str | None, str | None]:
"""Resolve cross-workflow parent linking from metadata.
Extracts typed parent IDs from the untyped ``parent_trace_context``
metadata dict (set by tool_node when invoking nested workflows).
Returns:
(trace_correlation_override, parent_span_id_source) where
trace_correlation_override is the outer workflow_run_id and
parent_span_id_source is the outer node_execution_id.
"""
parent_ctx = self.metadata.get("parent_trace_context")
if not isinstance(parent_ctx, dict):
return None, None
trace_override = parent_ctx.get("parent_workflow_run_id")
parent_span = parent_ctx.get("parent_node_execution_id")
return (
trace_override if isinstance(trace_override, str) else None,
parent_span if isinstance(parent_span, str) else None,
)
@field_serializer("start_time", "end_time")
def serialize_datetime(self, dt: datetime | None) -> str | None:
if dt is None:
@@ -48,7 +90,10 @@ class WorkflowTraceInfo(BaseTraceInfo):
workflow_run_version: str
error: str | None = None
total_tokens: int
prompt_tokens: int | None = None
completion_tokens: int | None = None
file_list: list[str]
invoked_by: str | None = None
query: str
metadata: dict[str, Any]
@@ -59,7 +104,7 @@ class MessageTraceInfo(BaseTraceInfo):
answer_tokens: int
total_tokens: int
error: str | None = None
file_list: Union[str, dict[str, Any], list] | None = None
file_list: Union[str, dict[str, Any], list[Any]] | None = None
message_file_data: Any | None = None
conversation_mode: str
gen_ai_server_time_to_first_token: float | None = None
@@ -106,7 +151,7 @@ class ToolTraceInfo(BaseTraceInfo):
tool_config: dict[str, Any]
time_cost: Union[int, float]
tool_parameters: dict[str, Any]
file_url: Union[str, None, list] = None
file_url: Union[str, None, list[str]] = None
class GenerateNameTraceInfo(BaseTraceInfo):
@@ -114,6 +159,79 @@ class GenerateNameTraceInfo(BaseTraceInfo):
tenant_id: str
class PromptGenerationTraceInfo(BaseTraceInfo):
"""Trace information for prompt generation operations (rule-generate, code-generate, etc.)."""
tenant_id: str
user_id: str
app_id: str | None = None
operation_type: str
instruction: str
prompt_tokens: int
completion_tokens: int
total_tokens: int
model_provider: str
model_name: str
latency: float
total_price: float | None = None
currency: str | None = None
error: str | None = None
model_config = ConfigDict(protected_namespaces=())
class WorkflowNodeTraceInfo(BaseTraceInfo):
workflow_id: str
workflow_run_id: str
tenant_id: str
node_execution_id: str
node_id: str
node_type: str
title: str
status: str
error: str | None = None
elapsed_time: float
index: int
predecessor_node_id: str | None = None
total_tokens: int = 0
total_price: float = 0.0
currency: str | None = None
model_provider: str | None = None
model_name: str | None = None
prompt_tokens: int | None = None
completion_tokens: int | None = None
tool_name: str | None = None
iteration_id: str | None = None
iteration_index: int | None = None
loop_id: str | None = None
loop_index: int | None = None
parallel_id: str | None = None
node_inputs: Mapping[str, Any] | None = None
node_outputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
invoked_by: str | None = None
model_config = ConfigDict(protected_namespaces=())
class DraftNodeExecutionTrace(WorkflowNodeTraceInfo):
pass
class TaskData(BaseModel):
app_id: str
trace_info_type: str
@@ -128,11 +246,31 @@ trace_info_info_map = {
"DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
"ToolTraceInfo": ToolTraceInfo,
"GenerateNameTraceInfo": GenerateNameTraceInfo,
"PromptGenerationTraceInfo": PromptGenerationTraceInfo,
"WorkflowNodeTraceInfo": WorkflowNodeTraceInfo,
"DraftNodeExecutionTrace": DraftNodeExecutionTrace,
}
class OperationType(StrEnum):
"""Operation type for token metric labels.
Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output``
counters so consumers can break down token usage by operation.
"""
WORKFLOW = "workflow"
NODE_EXECUTION = "node_execution"
MESSAGE = "message"
RULE_GENERATE = "rule_generate"
CODE_GENERATE = "code_generate"
STRUCTURED_OUTPUT = "structured_output"
INSTRUCTION_MODIFY = "instruction_modify"
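# Hedged sketch of how a metrics consumer might attach this label to the token
# counters named above (OpenTelemetry metrics API; the attribute key
# "operation_type" and the meter name are assumptions, not confirmed here):
#
#     from opentelemetry import metrics
#     meter = metrics.get_meter("dify.telemetry")
#     input_tokens = meter.create_counter("dify.tokens.input")
#     input_tokens.add(128, {"operation_type": OperationType.WORKFLOW})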
class TraceTaskName(StrEnum):
CONVERSATION_TRACE = "conversation"
DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution"
WORKFLOW_TRACE = "workflow"
MESSAGE_TRACE = "message"
MODERATION_TRACE = "moderation"
@@ -140,4 +278,6 @@ class TraceTaskName(StrEnum):
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
TOOL_TRACE = "tool"
GENERATE_NAME_TRACE = "generate_conversation_name"
PROMPT_GENERATION_TRACE = "prompt_generation"
NODE_EXECUTION_TRACE = "node_execution"
DATASOURCE_TRACE = "datasource"

View File

@@ -15,22 +15,32 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum
from core.ops.entities.config_entity import (
OPS_FILE_PATH,
TracingProviderEnum,
)
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
DraftNodeExecutionTrace,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
PromptGenerationTraceInfo,
SuggestedQuestionTraceInfo,
TaskData,
ToolTraceInfo,
TraceTaskName,
WorkflowNodeTraceInfo,
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.engine import db
from models.account import Tenant
from models.dataset import Dataset
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.provider import Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderType
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from models.workflow import WorkflowAppLog
from tasks.ops_trace_task import process_trace_tasks
@@ -40,9 +50,144 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]:
"""Return (app_name, workspace_name) for the given IDs. Falls back to empty strings."""
app_name = ""
workspace_name = ""
if not app_id and not tenant_id:
return app_name, workspace_name
with Session(db.engine) as session:
if app_id:
name = session.scalar(select(App.name).where(App.id == app_id))
if name:
app_name = name
if tenant_id:
name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id))
if name:
workspace_name = name
return app_name, workspace_name
_PROVIDER_TYPE_TO_MODEL: dict[str, type] = {
"builtin": BuiltinToolProvider,
"plugin": BuiltinToolProvider,
"api": ApiToolProvider,
"workflow": WorkflowToolProvider,
"mcp": MCPToolProvider,
}
def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str:
if not credential_id:
return ""
model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "")
if not model_cls:
return ""
with Session(db.engine) as session:
name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id)) # type: ignore[attr-defined]
return str(name) if name else ""
def _lookup_llm_credential_info(
tenant_id: str | None, provider: str | None, model: str | None, model_type: str | None = "llm"
) -> tuple[str | None, str]:
"""
Lookup LLM credential ID and name for the given provider and model.
Returns (credential_id, credential_name).
Handles async timing issues gracefully: if the credential is deleted between lookups,
returns the ID with an empty name rather than failing.
"""
if not tenant_id or not provider:
return None, ""
try:
with Session(db.engine) as session:
# Try to find provider-level or model-level configuration
provider_record = session.scalar(
select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == provider,
Provider.provider_type == ProviderType.CUSTOM,
)
)
if not provider_record:
return None, ""
# Check if there's a model-specific config
credential_id = None
credential_name = ""
is_model_level = False
if model:
# Try model-level first
model_record = session.scalar(
select(ProviderModel).where(
ProviderModel.tenant_id == tenant_id,
ProviderModel.provider_name == provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type,
)
)
if model_record and model_record.credential_id:
credential_id = model_record.credential_id
is_model_level = True
if not credential_id and provider_record.credential_id:
# Fall back to provider-level credential
credential_id = provider_record.credential_id
is_model_level = False
# Lookup credential_name if we have credential_id
if credential_id:
try:
if is_model_level:
# Query ProviderModelCredential
cred_name = session.scalar(
select(ProviderModelCredential.credential_name).where(
ProviderModelCredential.id == credential_id
)
)
else:
# Query ProviderCredential
cred_name = session.scalar(
select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id)
)
if cred_name:
credential_name = str(cred_name)
except Exception as e:
# Credential might have been deleted between lookups (async timing)
# Return ID but empty name rather than failing
logger.warning(
"Failed to lookup credential name for credential_id=%s (provider=%s, model=%s): %s",
credential_id,
provider,
model,
str(e),
exc_info=True,
)
return credential_id, credential_name
except Exception as e:
# Database query failed or other unexpected error
# Return empty rather than propagating error to telemetry emission
logger.warning(
"Failed to lookup LLM credential info for tenant_id=%s, provider=%s, model=%s: %s",
tenant_id,
provider,
model,
str(e),
exc_info=True,
)
return None, ""
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
def __getitem__(self, key: str) -> dict[str, Any]:
match key:
def __getitem__(self, provider: str) -> dict[str, Any]:
match provider:
case TracingProviderEnum.LANGFUSE:
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
@@ -149,7 +294,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
}
case _:
raise KeyError(f"Unsupported tracing provider: {key}")
raise KeyError(f"Unsupported tracing provider: {provider}")
provider_config_map = OpsTraceProviderConfigMap()
@@ -314,6 +459,10 @@ class OpsTraceManager:
if app_id is None:
return None
# Handle storage_id format (tenant-{uuid}) - not a real app_id
if isinstance(app_id, str) and app_id.startswith("tenant-"):
return None
app: App | None = db.session.query(App).where(App.id == app_id).first()
if app is None:
@@ -466,8 +615,6 @@ class TraceTask:
@classmethod
def _get_workflow_run_repo(cls):
from repositories.factory import DifyAPIRepositoryFactory
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:
@@ -478,6 +625,77 @@ class TraceTask:
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo
@classmethod
def _calculate_workflow_token_split(
cls, session: "Session", workflow_run_id: str, tenant_id: str
) -> tuple[int, int]:
"""Sum prompt/completion tokens across all node executions for a workflow run.
Reads from the ``outputs`` column (where LLM nodes store ``usage.prompt_tokens``
and ``usage.completion_tokens``) rather than ``execution_metadata``, which only
carries ``total_tokens``. Projects only the ``outputs`` column to avoid loading
large JSON blobs unnecessarily.
"""
import json
from models.workflow import WorkflowNodeExecutionModel
rows = (
session.execute(
select(WorkflowNodeExecutionModel.outputs).where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
)
)
.scalars()
.all()
)
total_prompt = 0
total_completion = 0
for raw in rows:
if not raw:
continue
try:
outputs = json.loads(raw) if isinstance(raw, str) else raw
except (ValueError, TypeError):
continue
if not isinstance(outputs, dict):
continue
usage = outputs.get("usage")
if not isinstance(usage, dict):
continue
prompt = usage.get("prompt_tokens")
if isinstance(prompt, (int, float)):
total_prompt += int(prompt)
completion = usage.get("completion_tokens")
if isinstance(completion, (int, float)):
total_completion += int(completion)
return (total_prompt, total_completion)
@classmethod
def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str:
"""Extract user ID from metadata, prioritizing end_user over account.
Returns the actual user ID (end_user or account) who invoked the workflow,
regardless of invoke_from context.
"""
# Priority 1: End user (external users via API/WebApp)
if user_id := metadata.get("from_end_user_id"):
return f"end_user:{user_id}"
# Priority 2: Account user (internal users via console/debugger)
if user_id := metadata.get("from_account_id"):
return f"account:{user_id}"
# Priority 3: Generic user_id fallback (when neither of the above is present)
if user_id := metadata.get("user_id"):
return f"user:{user_id}"
return "anonymous"
def __init__(
self,
trace_type: Any,
@@ -491,6 +709,7 @@ class TraceTask:
self.trace_type = trace_type
self.message_id = message_id
self.workflow_run_id = workflow_execution.id_ if workflow_execution else None
self.workflow_total_tokens: int | None = workflow_execution.total_tokens if workflow_execution else None
self.conversation_id = conversation_id
self.user_id = user_id
self.timer = timer
@@ -498,6 +717,8 @@ class TraceTask:
self.app_id = None
self.trace_id = None
self.kwargs = kwargs
if user_id is not None and "user_id" not in self.kwargs:
self.kwargs["user_id"] = user_id
external_trace_id = kwargs.get("external_trace_id")
if external_trace_id:
self.trace_id = external_trace_id
@@ -509,9 +730,12 @@ class TraceTask:
preprocess_map = {
TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
workflow_run_id=self.workflow_run_id,
conversation_id=self.conversation_id,
user_id=self.user_id,
total_tokens_override=self.workflow_total_tokens,
),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
),
@@ -527,6 +751,9 @@ class TraceTask:
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
),
TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs),
TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs),
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs),
}
return preprocess_map.get(self.trace_type, lambda: None)()
@@ -541,6 +768,7 @@ class TraceTask:
workflow_run_id: str | None,
conversation_id: str | None,
user_id: str | None,
total_tokens_override: int | None = None,
):
if not workflow_run_id:
return {}
@@ -560,7 +788,7 @@ class TraceTask:
workflow_run_version = workflow_run.version
error = workflow_run.error or ""
total_tokens = workflow_run.total_tokens
total_tokens = total_tokens_override if total_tokens_override is not None else workflow_run.total_tokens
file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
@@ -581,8 +809,18 @@ class TraceTask:
Message.workflow_run_id == workflow_run_id,
)
message_id = session.scalar(message_data_stmt)
prompt_tokens, completion_tokens = self._calculate_workflow_token_split(
session, workflow_run_id=workflow_run_id, tenant_id=tenant_id
)
metadata = {
from core.telemetry.gateway import is_enterprise_telemetry_enabled
if is_enterprise_telemetry_enabled():
app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id)
else:
app_name, workspace_name = "", ""
metadata: dict[str, Any] = {
"workflow_id": workflow_id,
"conversation_id": conversation_id,
"workflow_run_id": workflow_run_id,
@@ -595,8 +833,14 @@ class TraceTask:
"triggered_from": workflow_run.triggered_from,
"user_id": user_id,
"app_id": workflow_run.app_id,
"app_name": app_name,
"workspace_name": workspace_name,
}
parent_trace_context = self.kwargs.get("parent_trace_context")
if parent_trace_context:
metadata["parent_trace_context"] = parent_trace_context
workflow_trace_info = WorkflowTraceInfo(
trace_id=self.trace_id,
workflow_data=workflow_run.to_dict(),
@@ -611,6 +855,8 @@ class TraceTask:
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
file_list=file_list,
query=query,
metadata=metadata,
@@ -618,10 +864,11 @@ class TraceTask:
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
invoked_by=self._get_user_id_from_metadata(metadata),
)
return workflow_trace_info
def message_trace(self, message_id: str | None):
def message_trace(self, message_id: str | None, **kwargs):
if not message_id:
return {}
message_data = get_message_data(message_id)
@@ -644,6 +891,19 @@ class TraceTask:
streaming_metrics = self._extract_streaming_metrics(message_data)
tenant_id = ""
with Session(db.engine) as session:
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
if tid:
tenant_id = str(tid)
from core.telemetry.gateway import is_enterprise_telemetry_enabled
if is_enterprise_telemetry_enabled():
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
else:
app_name, workspace_name = "", ""
metadata = {
"conversation_id": message_data.conversation_id,
"ls_provider": message_data.model_provider,
@@ -655,7 +915,14 @@ class TraceTask:
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
"message_id": message_id,
"tenant_id": tenant_id,
"app_id": message_data.app_id,
"user_id": message_data.from_end_user_id or message_data.from_account_id,
"app_name": app_name,
"workspace_name": workspace_name,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
message_tokens = message_data.message_tokens
@@ -672,7 +939,9 @@ class TraceTask:
outputs=message_data.answer,
file_list=file_list,
start_time=created_at,
end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
end_time=message_data.updated_at
if message_data.updated_at and message_data.updated_at > created_at
else created_at + timedelta(seconds=message_data.provider_response_latency),
metadata=metadata,
message_file_data=message_file_data,
conversation_mode=conversation_mode,
@@ -697,6 +966,8 @@ class TraceTask:
"preset_response": moderation_result.preset_response,
"query": moderation_result.query,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
# get workflow_app_log_id
workflow_app_log_id = None
@@ -738,6 +1009,8 @@ class TraceTask:
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
# get workflow_app_log_id
workflow_app_log_id = None
@@ -777,6 +1050,52 @@ class TraceTask:
if not message_data:
return {}
tenant_id = ""
with Session(db.engine) as session:
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
if tid:
tenant_id = str(tid)
from core.telemetry.gateway import is_enterprise_telemetry_enabled
if is_enterprise_telemetry_enabled():
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
else:
app_name, workspace_name = "", ""
doc_list = [doc.model_dump() for doc in documents] if documents else []
dataset_ids: set[str] = set()
for doc in doc_list:
doc_meta = doc.get("metadata") or {}
did = doc_meta.get("dataset_id")
if did:
dataset_ids.add(did)
embedding_models: dict[str, dict[str, str]] = {}
if dataset_ids:
with Session(db.engine) as session:
rows = session.execute(
select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where(
Dataset.id.in_(list(dataset_ids))
)
).all()
for row in rows:
embedding_models[str(row[0])] = {
"embedding_model": row[1] or "",
"embedding_model_provider": row[2] or "",
}
# Extract rerank model info from retrieval_model kwargs
rerank_model_provider = ""
rerank_model_name = ""
if "retrieval_model" in kwargs:
retrieval_model = kwargs["retrieval_model"]
if isinstance(retrieval_model, dict):
reranking_model = retrieval_model.get("reranking_model")
if isinstance(reranking_model, dict):
rerank_model_provider = reranking_model.get("reranking_provider_name", "")
rerank_model_name = reranking_model.get("reranking_model_name", "")
metadata = {
"message_id": message_id,
"ls_provider": message_data.model_provider,
@@ -787,13 +1106,23 @@ class TraceTask:
"agent_based": message_data.agent_based,
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
"tenant_id": tenant_id,
"app_id": message_data.app_id,
"user_id": message_data.from_end_user_id or message_data.from_account_id,
"app_name": app_name,
"workspace_name": workspace_name,
"embedding_models": embedding_models,
"rerank_model_provider": rerank_model_provider,
"rerank_model_name": rerank_model_name,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
trace_id=self.trace_id,
message_id=message_id,
inputs=message_data.query or message_data.inputs,
documents=[doc.model_dump() for doc in documents] if documents else [],
documents=doc_list,
start_time=timer.get("start"),
end_time=timer.get("end"),
metadata=metadata,
@@ -836,6 +1165,10 @@ class TraceTask:
"error": error,
"tool_parameters": tool_parameters,
}
if message_data.workflow_run_id:
metadata["workflow_run_id"] = message_data.workflow_run_id
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
file_url = ""
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
@@ -890,6 +1223,8 @@ class TraceTask:
"conversation_id": conversation_id,
"tenant_id": tenant_id,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
generate_name_trace_info = GenerateNameTraceInfo(
trace_id=self.trace_id,
@@ -904,6 +1239,182 @@ class TraceTask:
return generate_name_trace_info
def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict:
tenant_id = kwargs.get("tenant_id", "")
user_id = kwargs.get("user_id", "")
app_id = kwargs.get("app_id")
operation_type = kwargs.get("operation_type", "")
instruction = kwargs.get("instruction", "")
generated_output = kwargs.get("generated_output", "")
prompt_tokens = kwargs.get("prompt_tokens", 0)
completion_tokens = kwargs.get("completion_tokens", 0)
total_tokens = kwargs.get("total_tokens", 0)
model_provider = kwargs.get("model_provider", "")
model_name = kwargs.get("model_name", "")
latency = kwargs.get("latency", 0.0)
timer = kwargs.get("timer")
start_time = timer.get("start") if timer else None
end_time = timer.get("end") if timer else None
total_price = kwargs.get("total_price")
currency = kwargs.get("currency")
error = kwargs.get("error")
app_name = None
workspace_name = None
if app_id:
app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id)
metadata = {
"tenant_id": tenant_id,
"user_id": user_id,
"app_id": app_id or "",
"app_name": app_name,
"workspace_name": workspace_name,
"operation_type": operation_type,
"model_provider": model_provider,
"model_name": model_name,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
return PromptGenerationTraceInfo(
trace_id=self.trace_id,
inputs=instruction,
outputs=generated_output,
start_time=start_time,
end_time=end_time,
metadata=metadata,
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type=operation_type,
instruction=instruction,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model_provider=model_provider,
model_name=model_name,
latency=latency,
total_price=total_price,
currency=currency,
error=error,
)
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
node_data: dict = kwargs.get("node_execution_data", {})
if not node_data:
return {}
from core.telemetry.gateway import is_enterprise_telemetry_enabled
if is_enterprise_telemetry_enabled():
app_name, workspace_name = _lookup_app_and_workspace_names(
node_data.get("app_id"), node_data.get("tenant_id")
)
else:
app_name, workspace_name = "", ""
# Try tool credential lookup first
credential_id = node_data.get("credential_id")
if is_enterprise_telemetry_enabled():
credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type"))
# If no credential_id found (e.g., LLM nodes), try LLM credential lookup
if not credential_id:
llm_cred_id, llm_cred_name = _lookup_llm_credential_info(
tenant_id=node_data.get("tenant_id"),
provider=node_data.get("model_provider"),
model=node_data.get("model_name"),
model_type="llm",
)
if llm_cred_id:
credential_id = llm_cred_id
credential_name = llm_cred_name
else:
credential_name = ""
metadata: dict[str, Any] = {
"tenant_id": node_data.get("tenant_id"),
"app_id": node_data.get("app_id"),
"app_name": app_name,
"workspace_name": workspace_name,
"user_id": node_data.get("user_id"),
"invoke_from": node_data.get("invoke_from"),
"credential_id": credential_id,
"credential_name": credential_name,
"dataset_ids": node_data.get("dataset_ids"),
"dataset_names": node_data.get("dataset_names"),
"plugin_name": node_data.get("plugin_name"),
}
parent_trace_context = node_data.get("parent_trace_context")
if parent_trace_context:
metadata["parent_trace_context"] = parent_trace_context
message_id: str | None = None
conversation_id = node_data.get("conversation_id")
workflow_execution_id = node_data.get("workflow_execution_id")
if conversation_id and workflow_execution_id and not parent_trace_context:
with Session(db.engine) as session:
msg_id = session.scalar(
select(Message.id).where(
Message.conversation_id == conversation_id,
Message.workflow_run_id == workflow_execution_id,
)
)
if msg_id:
message_id = str(msg_id)
metadata["message_id"] = message_id
if conversation_id:
metadata["conversation_id"] = conversation_id
return WorkflowNodeTraceInfo(
trace_id=self.trace_id,
message_id=message_id,
start_time=node_data.get("created_at"),
end_time=node_data.get("finished_at"),
metadata=metadata,
workflow_id=node_data.get("workflow_id", ""),
workflow_run_id=node_data.get("workflow_execution_id", ""),
tenant_id=node_data.get("tenant_id", ""),
node_execution_id=node_data.get("node_execution_id", ""),
node_id=node_data.get("node_id", ""),
node_type=node_data.get("node_type", ""),
title=node_data.get("title", ""),
status=node_data.get("status", ""),
error=node_data.get("error"),
elapsed_time=node_data.get("elapsed_time", 0.0),
index=node_data.get("index", 0),
predecessor_node_id=node_data.get("predecessor_node_id"),
total_tokens=node_data.get("total_tokens", 0),
total_price=node_data.get("total_price", 0.0),
currency=node_data.get("currency"),
model_provider=node_data.get("model_provider"),
model_name=node_data.get("model_name"),
prompt_tokens=node_data.get("prompt_tokens"),
completion_tokens=node_data.get("completion_tokens"),
tool_name=node_data.get("tool_name"),
iteration_id=node_data.get("iteration_id"),
iteration_index=node_data.get("iteration_index"),
loop_id=node_data.get("loop_id"),
loop_index=node_data.get("loop_index"),
parallel_id=node_data.get("parallel_id"),
node_inputs=node_data.get("node_inputs"),
node_outputs=node_data.get("node_outputs"),
process_data=node_data.get("process_data"),
invoked_by=self._get_user_id_from_metadata(metadata),
)
def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict:
node_trace = self.node_execution_trace(**kwargs)
if not isinstance(node_trace, WorkflowNodeTraceInfo):
return node_trace
return DraftNodeExecutionTrace(**node_trace.model_dump())
def _extract_streaming_metrics(self, message_data) -> dict:
if not message_data.message_metadata:
return {}
@@ -937,13 +1448,17 @@ class TraceQueueManager:
self.user_id = user_id
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
self.flask_app = current_app._get_current_object() # type: ignore
from core.telemetry.gateway import is_enterprise_telemetry_enabled
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
if trace_manager_timer is None:
self.start_timer()
def add_trace_task(self, trace_task: TraceTask):
global trace_manager_timer, trace_manager_queue
try:
if self.trace_instance:
if self._enterprise_telemetry_enabled or self.trace_instance:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
except Exception:
@@ -979,20 +1494,27 @@ class TraceQueueManager:
def send_to_celery(self, tasks: list[TraceTask]):
with self.flask_app.app_context():
for task in tasks:
if task.app_id is None:
continue
storage_id = task.app_id
if storage_id is None:
tenant_id = task.kwargs.get("tenant_id")
if tenant_id:
storage_id = f"tenant-{tenant_id}"
else:
logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type)
continue
file_id = uuid4().hex
trace_info = task.execute()
task_data = TaskData(
app_id=task.app_id,
app_id=storage_id,
trace_info_type=type(trace_info).__name__,
trace_info=trace_info.model_dump() if trace_info else None,
)
file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json"
storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
file_info = {
"file_id": file_id,
"app_id": task.app_id,
"app_id": storage_id,
}
process_trace_tasks.delay(file_info) # type: ignore

View File

@@ -13,6 +13,7 @@ from core.plugin.endpoint.exc import EndpointSetupFailedError
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError
from core.plugin.impl.exc import (
PluginDaemonBadRequestError,
PluginDaemonClientSideError,
PluginDaemonInternalServerError,
PluginDaemonNotFoundError,
PluginDaemonUnauthorizedError,
@@ -235,7 +236,10 @@ class BasePluginClient:
response.raise_for_status()
except httpx.HTTPStatusError as e:
logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path)
raise e
if e.response.status_code < 500:
raise PluginDaemonClientSideError(description=str(e))
else:
raise PluginDaemonInternalServerError(description=str(e))
except Exception as e:
msg = f"Failed to request plugin daemon, url: {path}"
logger.exception("Failed to request plugin daemon, url: %s", path)

View File

@@ -0,0 +1,43 @@
"""Telemetry facade.
Thin public API for emitting telemetry events. All routing logic
lives in ``core.telemetry.gateway``, which is shared by both CE and EE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.ops.entities.trace_entity import TraceTaskName
from core.telemetry.events import TelemetryContext, TelemetryEvent
from core.telemetry.gateway import emit as gateway_emit
from core.telemetry.gateway import get_trace_task_to_case
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None:
"""Emit a telemetry event.
Translates the ``TelemetryEvent`` (keyed by ``TraceTaskName``) into a
``TelemetryCase`` and delegates to ``core.telemetry.gateway.emit()``.
"""
case = get_trace_task_to_case().get(event.name)
if case is None:
return
context: dict[str, object] = {
"tenant_id": event.context.tenant_id,
"user_id": event.context.user_id,
"app_id": event.context.app_id,
}
gateway_emit(case, context, event.payload, trace_manager)
__all__ = [
"TelemetryContext",
"TelemetryEvent",
"TraceTaskName",
"emit",
]

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from core.ops.entities.trace_entity import TraceTaskName
@dataclass(frozen=True)
class TelemetryContext:
tenant_id: str | None = None
user_id: str | None = None
app_id: str | None = None
@dataclass(frozen=True)
class TelemetryEvent:
name: TraceTaskName
context: TelemetryContext
payload: dict[str, Any]

View File

@@ -0,0 +1,239 @@
"""Telemetry gateway — single routing layer for all editions.
Maps ``TelemetryCase`` → ``CaseRoute`` and dispatches events to either
the CE/EE trace pipeline (``TraceQueueManager``) or the enterprise-only
metric/log Celery queue.
This module lives in ``core/`` so both CE and EE share one routing table
and one ``emit()`` entry point. No separate enterprise gateway module is
needed — enterprise-specific dispatch (Celery task, payload offloading)
is handled here behind lazy imports that no-op in CE.
"""
from __future__ import annotations
import json
import logging
import uuid
from typing import TYPE_CHECKING, Any
from core.ops.entities.trace_entity import TraceTaskName
from enterprise.telemetry.contracts import CaseRoute, SignalType
from extensions.ext_storage import storage
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
from enterprise.telemetry.contracts import TelemetryCase
logger = logging.getLogger(__name__)
PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024
# ---------------------------------------------------------------------------
# Routing table — authoritative mapping for all editions
# ---------------------------------------------------------------------------
_case_to_trace_task: dict[TelemetryCase, TraceTaskName] | None = None
_case_routing: dict[TelemetryCase, CaseRoute] | None = None
def _get_case_to_trace_task() -> dict[TelemetryCase, TraceTaskName]:
global _case_to_trace_task
if _case_to_trace_task is None:
from enterprise.telemetry.contracts import TelemetryCase
_case_to_trace_task = {
TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE,
TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE,
TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE,
TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE,
TelemetryCase.TOOL_EXECUTION: TraceTaskName.TOOL_TRACE,
TelemetryCase.MODERATION_CHECK: TraceTaskName.MODERATION_TRACE,
TelemetryCase.SUGGESTED_QUESTION: TraceTaskName.SUGGESTED_QUESTION_TRACE,
TelemetryCase.DATASET_RETRIEVAL: TraceTaskName.DATASET_RETRIEVAL_TRACE,
TelemetryCase.GENERATE_NAME: TraceTaskName.GENERATE_NAME_TRACE,
}
return _case_to_trace_task
def get_trace_task_to_case() -> dict[TraceTaskName, TelemetryCase]:
"""Return TraceTaskName → TelemetryCase (inverse of _get_case_to_trace_task)."""
return {v: k for k, v in _get_case_to_trace_task().items()}
def _get_case_routing() -> dict[TelemetryCase, CaseRoute]:
global _case_routing
if _case_routing is None:
from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase
_case_routing = {
# TRACE — CE-eligible (flow in both CE and EE)
TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
# TRACE — enterprise-only
TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
# METRIC_LOG — enterprise-only (signal-driven, not trace)
TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.APP_DELETED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
}
return _case_routing
def __getattr__(name: str) -> dict:
"""Lazy module-level access to routing tables."""
if name == "CASE_ROUTING":
return _get_case_routing()
if name == "CASE_TO_TRACE_TASK":
return _get_case_to_trace_task()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def is_enterprise_telemetry_enabled() -> bool:
try:
from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled
return is_enterprise_telemetry_enabled()
except Exception:
return False
def _handle_payload_sizing(
payload: dict[str, Any],
tenant_id: str,
event_id: str,
) -> tuple[dict[str, Any], str | None]:
"""Inline or offload payload based on size.
Returns ``(payload_for_envelope, storage_key | None)``. Payloads
exceeding ``PAYLOAD_SIZE_THRESHOLD_BYTES`` are written to object
storage and replaced with an empty dict in the envelope.
"""
try:
payload_json = json.dumps(payload)
payload_size = len(payload_json.encode("utf-8"))
except (TypeError, ValueError):
logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id)
return payload, None
if payload_size <= PAYLOAD_SIZE_THRESHOLD_BYTES:
return payload, None
storage_key = f"telemetry/{tenant_id}/{event_id}.json"
try:
storage.save(storage_key, payload_json.encode("utf-8"))
logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, payload_size)
return {}, storage_key
except Exception:
logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True)
return payload, None
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def emit(
case: TelemetryCase,
context: dict[str, Any],
payload: dict[str, Any],
trace_manager: TraceQueueManager | None = None,
) -> None:
"""Route a telemetry event to the correct pipeline.
TRACE events are enqueued into ``TraceQueueManager`` (works in both CE
and EE). Enterprise-only traces are silently dropped when EE is
disabled.
METRIC_LOG events are dispatched to the enterprise Celery queue;
silently dropped when enterprise telemetry is unavailable.
"""
route = _get_case_routing().get(case)
if route is None:
logger.warning("Unknown telemetry case: %s, dropping event", case)
return
if not route.ce_eligible and not is_enterprise_telemetry_enabled():
logger.debug("Dropping EE-only event: case=%s (EE disabled)", case)
return
if route.signal_type == SignalType.TRACE:
_emit_trace(case, context, payload, trace_manager)
else:
_emit_metric_log(case, context, payload)
def _emit_trace(
case: TelemetryCase,
context: dict[str, Any],
payload: dict[str, Any],
trace_manager: TraceQueueManager | None,
) -> None:
from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager
from core.ops.ops_trace_manager import TraceTask
trace_task_name = _get_case_to_trace_task().get(case)
if trace_task_name is None:
logger.warning("No TraceTaskName mapping for case: %s", case)
return
queue_manager = trace_manager or LocalTraceQueueManager(
app_id=context.get("app_id"),
user_id=context.get("user_id"),
)
queue_manager.add_trace_task(TraceTask(trace_task_name, user_id=context.get("user_id"), **payload))
logger.debug("Enqueued trace task: case=%s, app_id=%s", case, context.get("app_id"))
def _emit_metric_log(
case: TelemetryCase,
context: dict[str, Any],
payload: dict[str, Any],
) -> None:
"""Build envelope and dispatch to enterprise Celery queue.
No-ops when the enterprise telemetry task is not importable (CE mode).
"""
try:
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
except ImportError:
logger.debug("Enterprise metric/log dispatch unavailable, dropping: case=%s", case)
return
tenant_id = context.get("tenant_id") or ""
event_id = str(uuid.uuid4())
payload_for_envelope, payload_ref = _handle_payload_sizing(payload, tenant_id, event_id)
from enterprise.telemetry.contracts import TelemetryEnvelope
envelope = TelemetryEnvelope(
case=case,
tenant_id=tenant_id,
event_id=event_id,
payload=payload_for_envelope,
metadata={"payload_ref": payload_ref} if payload_ref else None,
)
process_enterprise_telemetry.delay(envelope.model_dump_json())
logger.debug(
"Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s",
case,
tenant_id,
event_id,
)

View File

@@ -0,0 +1,525 @@
# Dify Enterprise Telemetry Data Dictionary
Quick reference for all telemetry signals emitted by Dify Enterprise. For configuration and architecture details, see [README.md](./README.md).
## Resource Attributes
Attached to every signal (Span, Metric, Log).
| Attribute | Type | Example |
|-----------|------|---------|
| `service.name` | string | `dify` |
| `host.name` | string | `dify-api-7f8b` |
## Traces (Spans)
### `dify.workflow.run`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.trace_id` | string | Business trace ID (Workflow Run ID) |
| `dify.tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.workflow.id` | string | Workflow definition ID |
| `dify.workflow.run_id` | string | Unique ID for this run |
| `dify.workflow.status` | string | `succeeded`, `failed`, `stopped`, etc. |
| `dify.workflow.error` | string | Error message if failed |
| `dify.workflow.elapsed_time` | float | Total execution time (seconds) |
| `dify.invoke_from` | string | `api`, `webapp`, `debug` |
| `dify.conversation.id` | string | Conversation ID (optional) |
| `dify.message.id` | string | Message ID (optional) |
| `dify.invoked_by` | string | User ID who triggered the run |
| `gen_ai.usage.total_tokens` | int | Total tokens across all nodes (optional) |
| `gen_ai.user.id` | string | End-user identifier (optional) |
| `dify.parent.trace_id` | string | Parent workflow trace ID (optional) |
| `dify.parent.workflow.run_id` | string | Parent workflow run ID (optional) |
| `dify.parent.node.execution_id` | string | Parent node execution ID (optional) |
| `dify.parent.app.id` | string | Parent app ID (optional) |
### `dify.node.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.trace_id` | string | Business trace ID |
| `dify.tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.workflow.id` | string | Workflow definition ID |
| `dify.workflow.run_id` | string | Workflow Run ID |
| `dify.message.id` | string | Message ID (optional) |
| `dify.conversation.id` | string | Conversation ID (optional) |
| `dify.node.execution_id` | string | Unique node execution ID |
| `dify.node.id` | string | Node ID in workflow graph |
| `dify.node.type` | string | Node type (see appendix) |
| `dify.node.title` | string | Display title |
| `dify.node.status` | string | `succeeded`, `failed` |
| `dify.node.error` | string | Error message if failed |
| `dify.node.elapsed_time` | float | Execution time (seconds) |
| `dify.node.index` | int | Execution order index |
| `dify.node.predecessor_node_id` | string | Triggering node ID |
| `dify.node.iteration_id` | string | Iteration ID (optional) |
| `dify.node.loop_id` | string | Loop ID (optional) |
| `dify.node.parallel_id` | string | Parallel branch ID (optional) |
| `dify.node.invoked_by` | string | User ID who triggered execution |
| `gen_ai.usage.input_tokens` | int | Prompt tokens (LLM nodes only) |
| `gen_ai.usage.output_tokens` | int | Completion tokens (LLM nodes only) |
| `gen_ai.usage.total_tokens` | int | Total tokens (LLM nodes only) |
| `gen_ai.request.model` | string | LLM model name (LLM nodes only) |
| `gen_ai.provider.name` | string | LLM provider name (LLM nodes only) |
| `gen_ai.user.id` | string | End-user identifier (optional) |
### `dify.node.execution.draft`
Same attributes as `dify.node.execution`. Emitted during Preview/Debug runs.
## Counters
All counters are cumulative and emitted at 100% accuracy.
### Token Counters
| Metric | Unit | Description |
|--------|------|-------------|
| `dify.tokens.total` | `{token}` | Total tokens consumed |
| `dify.tokens.input` | `{token}` | Input (prompt) tokens |
| `dify.tokens.output` | `{token}` | Output (completion) tokens |
**Labels:**
- `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type` (if node_execution)
⚠️ **Warning:** `dify.tokens.total` at workflow level includes all node tokens. Filter by `operation_type` to avoid double-counting.
#### Token Hierarchy & Query Patterns
Token metrics are emitted at multiple layers. Understanding the hierarchy prevents double-counting:
```
App-level total
├── workflow ← sum of all node_execution tokens (DO NOT add both)
│ └── node_execution ← per-node breakdown
├── message ← independent (non-workflow chat apps only)
├── rule_generate ← independent helper LLM call
├── code_generate ← independent helper LLM call
├── structured_output ← independent helper LLM call
└── instruction_modify ← independent helper LLM call
```
**Key rule:** `workflow` tokens already include all `node_execution` tokens. Never sum both.
**Available labels on token metrics:** `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type`.
App name is only available on span attributes (`dify.app.name`), not metric labels — use `app_id` for metric queries.
**Common queries** (PromQL):
```promql
# ── Totals ──────────────────────────────────────────────────
# App-level total (exclude node_execution to avoid double-counting)
sum by (app_id) (dify_tokens_total{operation_type!="node_execution"})
# Single app total
sum (dify_tokens_total{app_id="<app_id>", operation_type!="node_execution"})
# Per-tenant totals
sum by (tenant_id) (dify_tokens_total{operation_type!="node_execution"})
# ── Drill-down ──────────────────────────────────────────────
# Workflow-level tokens for an app
sum (dify_tokens_total{app_id="<app_id>", operation_type="workflow"})
# Node-level breakdown within an app
sum by (node_type) (dify_tokens_total{app_id="<app_id>", operation_type="node_execution"})
# Model breakdown for an app
sum by (model_provider, model_name) (dify_tokens_total{app_id="<app_id>"})
# Input vs output per model
sum by (model_name) (dify_tokens_input_total{app_id="<app_id>"})
sum by (model_name) (dify_tokens_output_total{app_id="<app_id>"})
# ── Rates ───────────────────────────────────────────────────
# Token consumption rate (per hour)
sum(rate(dify_tokens_total{operation_type!="node_execution"}[1h]))
# Per-app consumption rate
sum by (app_id) (rate(dify_tokens_total{operation_type!="node_execution"}[1h]))
```
**Finding `app_id` from app name** (trace query — Tempo / Jaeger):
```
{ resource.dify.app.name = "My Chatbot" } | select(resource.dify.app.id)
```
### Request Counters
| Metric | Unit | Description |
|--------|------|-------------|
| `dify.requests.total` | `{request}` | Total operations count |
**Labels by type:**
| `type` | Additional Labels |
|--------|-------------------|
| `workflow` | `tenant_id`, `app_id`, `status`, `invoke_from` |
| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` |
| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` |
| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name`, `status`, `invoke_from` |
| `tool` | `tenant_id`, `app_id`, `tool_name` |
| `moderation` | `tenant_id`, `app_id` |
| `suggested_question` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `dataset_retrieval` | `tenant_id`, `app_id` |
| `generate_name` | `tenant_id`, `app_id` |
| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `status` |
### Error Counters
| Metric | Unit | Description |
|--------|------|-------------|
| `dify.errors.total` | `{error}` | Total failed operations |
**Labels by type:**
| `type` | Additional Labels |
|--------|-------------------|
| `workflow` | `tenant_id`, `app_id` |
| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` |
| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` |
| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `tool` | `tenant_id`, `app_id`, `tool_name` |
| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` |
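For example, a per-app workflow error ratio can be derived by dividing errors by requests. The metric names below follow the same OTLP-to-Prometheus naming convention as the token queries above (`dify.errors.total` → `dify_errors_total`); adjust them if your collector renames metrics.
```promql
# Workflow error ratio per app over the last hour
sum by (app_id) (rate(dify_errors_total{type="workflow"}[1h]))
/
sum by (app_id) (rate(dify_requests_total{type="workflow"}[1h]))
```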
### Other Counters
| Metric | Unit | Labels |
|--------|------|--------|
| `dify.feedback.total` | `{feedback}` | `tenant_id`, `app_id`, `rating` |
| `dify.dataset.retrievals.total` | `{retrieval}` | `tenant_id`, `app_id`, `dataset_id`, `embedding_model_provider`, `embedding_model`, `rerank_model_provider`, `rerank_model` |
| `dify.app.created.total` | `{app}` | `tenant_id`, `app_id`, `mode` |
| `dify.app.updated.total` | `{app}` | `tenant_id`, `app_id` |
| `dify.app.deleted.total` | `{app}` | `tenant_id`, `app_id` |
## Histograms
| Metric | Unit | Labels |
|--------|------|--------|
| `dify.workflow.duration` | `s` | `tenant_id`, `app_id`, `status` |
| `dify.node.duration` | `s` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `plugin_name` |
| `dify.message.duration` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `dify.message.time_to_first_token` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `dify.tool.duration` | `s` | `tenant_id`, `app_id`, `tool_name` |
| `dify.prompt_generation.duration` | `s` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` |
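Duration histograms support standard quantile queries. A sketch, assuming the collector exports buckets with the conventional `_bucket` suffix and a seconds unit (e.g. `dify_workflow_duration_seconds_bucket`); verify the exact name your pipeline produces:
```promql
# p95 workflow duration per app over the last hour
histogram_quantile(
  0.95,
  sum by (app_id, le) (rate(dify_workflow_duration_seconds_bucket{status="succeeded"}[1h]))
)
```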
## Structured Logs
### Span Companion Logs
Logs that accompany spans. Signal type: `span_detail`
#### `dify.workflow.run` Companion Log
**Common attributes:** All span attributes (see Traces section) plus:
| Additional Attribute | Type | Always Present | Description |
|---------------------|------|----------------|-------------|
| `dify.app.name` | string | No | Application display name |
| `dify.workspace.name` | string | No | Workspace display name |
| `dify.workflow.version` | string | Yes | Workflow definition version |
| `dify.workflow.inputs` | string/JSON | Yes | Input parameters (content-gated) |
| `dify.workflow.outputs` | string/JSON | Yes | Output results (content-gated) |
| `dify.workflow.query` | string | No | User query text (content-gated) |
**Event attributes:**
- `dify.event.name`: `"dify.workflow.run"`
- `dify.event.signal`: `"span_detail"`
- `trace_id`, `span_id`, `tenant_id`, `user_id`
#### `dify.node.execution` and `dify.node.execution.draft` Companion Logs
**Common attributes:** All span attributes (see Traces section) plus:
| Additional Attribute | Type | Always Present | Description |
|---------------------|------|----------------|-------------|
| `dify.app.name` | string | No | Application display name |
| `dify.workspace.name` | string | No | Workspace display name |
| `dify.invoke_from` | string | No | Invocation source |
| `gen_ai.tool.name` | string | No | Tool name (tool nodes only) |
| `dify.node.total_price` | float | No | Cost (LLM nodes only) |
| `dify.node.currency` | string | No | Currency code (LLM nodes only) |
| `dify.node.iteration_index` | int | No | Iteration index (iteration nodes) |
| `dify.node.loop_index` | int | No | Loop index (loop nodes) |
| `dify.plugin.name` | string | No | Plugin name (tool/knowledge nodes) |
| `dify.credential.name` | string | No | Credential name (plugin nodes) |
| `dify.credential.id` | string | No | Credential ID (plugin nodes) |
| `dify.dataset.ids` | JSON array | No | Dataset IDs (knowledge nodes) |
| `dify.dataset.names` | JSON array | No | Dataset names (knowledge nodes) |
| `dify.node.inputs` | string/JSON | Yes | Node inputs (content-gated) |
| `dify.node.outputs` | string/JSON | Yes | Node outputs (content-gated) |
| `dify.node.process_data` | string/JSON | No | Processing data (content-gated) |
**Event attributes:**
- `dify.event.name`: `"dify.node.execution"` or `"dify.node.execution.draft"`
- `dify.event.signal`: `"span_detail"`
- `trace_id`, `span_id`, `tenant_id`, `user_id`
### Standalone Logs
Logs without structural spans. Signal type: `metric_only`
#### `dify.message.run`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.message.run"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID (32-char hex) |
| `span_id` | string | OTEL span ID (16-char hex) |
| `tenant_id` | string | Tenant identifier |
| `user_id` | string | User identifier (optional) |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.conversation.id` | string | Conversation ID (optional) |
| `dify.workflow.run_id` | string | Workflow run ID (optional) |
| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` |
| `gen_ai.provider.name` | string | LLM provider |
| `gen_ai.request.model` | string | LLM model |
| `gen_ai.usage.input_tokens` | int | Input tokens |
| `gen_ai.usage.output_tokens` | int | Output tokens |
| `gen_ai.usage.total_tokens` | int | Total tokens |
| `dify.message.status` | string | `succeeded`, `failed` |
| `dify.message.error` | string | Error message (if failed) |
| `dify.message.duration` | float | Duration (seconds) |
| `dify.message.time_to_first_token` | float | TTFT (seconds) |
| `dify.message.inputs` | string/JSON | Inputs (content-gated) |
| `dify.message.outputs` | string/JSON | Outputs (content-gated) |
#### `dify.tool.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.tool.execution"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.tool.name` | string | Tool name |
| `dify.tool.duration` | float | Duration (seconds) |
| `dify.tool.status` | string | `succeeded`, `failed` |
| `dify.tool.error` | string | Error message (if failed) |
| `dify.tool.inputs` | string/JSON | Inputs (content-gated) |
| `dify.tool.outputs` | string/JSON | Outputs (content-gated) |
| `dify.tool.parameters` | string/JSON | Parameters (content-gated) |
| `dify.tool.config` | string/JSON | Configuration (content-gated) |
#### `dify.moderation.check`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.moderation.check"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.moderation.type` | string | `input`, `output` |
| `dify.moderation.action` | string | `pass`, `block`, `flag` |
| `dify.moderation.flagged` | boolean | Whether flagged |
| `dify.moderation.categories` | JSON array | Flagged categories |
| `dify.moderation.query` | string | Content (content-gated) |
#### `dify.suggested_question.generation`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.suggested_question.generation"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.suggested_question.count` | int | Number of questions |
| `dify.suggested_question.duration` | float | Duration (seconds) |
| `dify.suggested_question.status` | string | `succeeded`, `failed` |
| `dify.suggested_question.error` | string | Error message (if failed) |
| `dify.suggested_question.questions` | JSON array | Questions (content-gated) |
#### `dify.dataset.retrieval`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.dataset.retrieval"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.dataset.id` | string | Dataset identifier |
| `dify.dataset.name` | string | Dataset name |
| `dify.dataset.embedding_providers` | JSON array | Embedding model providers (one per dataset) |
| `dify.dataset.embedding_models` | JSON array | Embedding models (one per dataset) |
| `dify.retrieval.rerank_provider` | string | Rerank model provider |
| `dify.retrieval.rerank_model` | string | Rerank model name |
| `dify.retrieval.query` | string | Search query (content-gated) |
| `dify.retrieval.document_count` | int | Documents retrieved |
| `dify.retrieval.duration` | float | Duration (seconds) |
| `dify.retrieval.status` | string | `succeeded`, `failed` |
| `dify.retrieval.error` | string | Error message (if failed) |
| `dify.dataset.documents` | JSON array | Documents (content-gated) |
#### `dify.generate_name.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.generate_name.execution"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.conversation.id` | string | Conversation identifier |
| `dify.generate_name.duration` | float | Duration (seconds) |
| `dify.generate_name.status` | string | `succeeded`, `failed` |
| `dify.generate_name.error` | string | Error message (if failed) |
| `dify.generate_name.inputs` | string/JSON | Inputs (content-gated) |
| `dify.generate_name.outputs` | string | Generated name (content-gated) |
#### `dify.prompt_generation.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.prompt_generation.execution"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.prompt_generation.operation_type` | string | Operation type (see appendix) |
| `gen_ai.provider.name` | string | LLM provider |
| `gen_ai.request.model` | string | LLM model |
| `gen_ai.usage.input_tokens` | int | Input tokens |
| `gen_ai.usage.output_tokens` | int | Output tokens |
| `gen_ai.usage.total_tokens` | int | Total tokens |
| `dify.prompt_generation.duration` | float | Duration (seconds) |
| `dify.prompt_generation.status` | string | `succeeded`, `failed` |
| `dify.prompt_generation.error` | string | Error message (if failed) |
| `dify.prompt_generation.instruction` | string | Instruction (content-gated) |
| `dify.prompt_generation.output` | string/JSON | Output (content-gated) |
#### `dify.app.created`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.app.created"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.app.mode` | string | `chat`, `completion`, `agent-chat`, `workflow` |
| `dify.app.created_at` | string | Timestamp (ISO 8601) |
#### `dify.app.updated`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.app.updated"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.app.updated_at` | string | Timestamp (ISO 8601) |
#### `dify.app.deleted`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.app.deleted"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.app.deleted_at` | string | Timestamp (ISO 8601) |
#### `dify.feedback.created`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.feedback.created"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.feedback.rating` | string | `like`, `dislike`, `null` |
| `dify.feedback.content` | string | Feedback text (content-gated) |
| `dify.feedback.created_at` | string | Timestamp (ISO 8601) |
#### `dify.telemetry.rehydration_failed`
Diagnostic event for telemetry system health monitoring.
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.telemetry.rehydration_failed"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.telemetry.error` | string | Error message |
| `dify.telemetry.payload_type` | string | Payload type (see appendix) |
| `dify.telemetry.correlation_id` | string | Correlation ID |
## Content-Gated Attributes
When `ENTERPRISE_INCLUDE_CONTENT=false`, these attributes are replaced with reference strings (`ref:{id_type}={uuid}`).
| Attribute | Signal |
|-----------|--------|
| `dify.workflow.inputs` | `dify.workflow.run` |
| `dify.workflow.outputs` | `dify.workflow.run` |
| `dify.workflow.query` | `dify.workflow.run` |
| `dify.node.inputs` | `dify.node.execution` |
| `dify.node.outputs` | `dify.node.execution` |
| `dify.node.process_data` | `dify.node.execution` |
| `dify.message.inputs` | `dify.message.run` |
| `dify.message.outputs` | `dify.message.run` |
| `dify.tool.inputs` | `dify.tool.execution` |
| `dify.tool.outputs` | `dify.tool.execution` |
| `dify.tool.parameters` | `dify.tool.execution` |
| `dify.tool.config` | `dify.tool.execution` |
| `dify.moderation.query` | `dify.moderation.check` |
| `dify.suggested_question.questions` | `dify.suggested_question.generation` |
| `dify.retrieval.query` | `dify.dataset.retrieval` |
| `dify.dataset.documents` | `dify.dataset.retrieval` |
| `dify.generate_name.inputs` | `dify.generate_name.execution` |
| `dify.generate_name.outputs` | `dify.generate_name.execution` |
| `dify.prompt_generation.instruction` | `dify.prompt_generation.execution` |
| `dify.prompt_generation.output` | `dify.prompt_generation.execution` |
| `dify.feedback.content` | `dify.feedback.created` |
## Appendix
### Operation Types
- `workflow`, `node_execution`, `message`, `rule_generate`, `code_generate`, `structured_output`, `instruction_modify`
### Node Types
- `start`, `end`, `answer`, `llm`, `knowledge-retrieval`, `knowledge-index`, `if-else`, `code`, `template-transform`, `question-classifier`, `http-request`, `tool`, `datasource`, `variable-aggregator`, `loop`, `iteration`, `parameter-extractor`, `assigner`, `document-extractor`, `list-operator`, `agent`, `trigger-webhook`, `trigger-schedule`, `trigger-plugin`, `human-input`
### Workflow Statuses
- `running`, `succeeded`, `failed`, `stopped`, `partial-succeeded`, `paused`
### Payload Types
- `workflow`, `node`, `message`, `tool`, `moderation`, `suggested_question`, `dataset_retrieval`, `generate_name`, `prompt_generation`, `app`, `feedback`
### Null Value Behavior
**Spans:** Attributes with `null` values are omitted.
**Logs:** Attributes with `null` values appear as `null` in JSON.
**Content-Gated:** Replaced with reference strings, not set to `null`.

View File

@@ -0,0 +1,121 @@
# Dify Enterprise Telemetry
This document provides an overview of the Dify Enterprise OpenTelemetry (OTEL) exporter and how to configure it for integration with observability stacks like Prometheus, Grafana, Jaeger, or Honeycomb.
## Overview
Dify Enterprise uses a "slim span + rich companion log" architecture to provide high-fidelity observability without overwhelming trace storage.
- **Traces (Spans)**: Capture the structure, identity, and timing of high-level operations (Workflows and Nodes).
- **Structured Logs**: Provide deep context (inputs, outputs, metadata) for every event, correlated to spans via `trace_id` and `span_id`.
- **Metrics**: Provide 100% accurate counters and histograms for usage, performance, and error tracking.
### Signal Architecture
```mermaid
graph TD
A[Workflow Run] -->|Span| B(dify.workflow.run)
A -->|Log| C(dify.workflow.run detail)
B ---|trace_id| C
D[Node Execution] -->|Span| E(dify.node.execution)
D -->|Log| F(dify.node.execution detail)
E ---|span_id| F
G[Message/Tool/etc] -->|Log| H(dify.* event)
G -->|Metric| I(dify.* counter/histogram)
```
## Configuration
The Enterprise OTEL exporter is configured via environment variables.
| Variable | Description | Default |
|----------|-------------|---------|
| `ENTERPRISE_ENABLED` | Master switch for all enterprise features. | `false` |
| `ENTERPRISE_TELEMETRY_ENABLED` | Master switch for enterprise telemetry. | `false` |
| `ENTERPRISE_OTLP_ENDPOINT` | OTLP collector endpoint (e.g., `http://otel-collector:4318`). | - |
| `ENTERPRISE_OTLP_HEADERS` | Custom headers for OTLP requests (e.g., `x-scope-orgid=tenant1`). | - |
| `ENTERPRISE_OTLP_PROTOCOL` | OTLP transport protocol (`http` or `grpc`). | `http` |
| `ENTERPRISE_OTLP_API_KEY` | Bearer token for authentication. | - |
| `ENTERPRISE_INCLUDE_CONTENT` | Whether to include sensitive content (inputs/outputs) in logs. | `false` |
| `ENTERPRISE_SERVICE_NAME` | Service name reported to OTEL. | `dify` |
| `ENTERPRISE_OTEL_SAMPLING_RATE` | Sampling rate for traces (0.0 to 1.0). Metrics are always 100%. | `1.0` |
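A minimal example configuration for an OTLP/HTTP collector (values are illustrative):
```
ENTERPRISE_ENABLED=true
ENTERPRISE_TELEMETRY_ENABLED=true
ENTERPRISE_OTLP_ENDPOINT=http://otel-collector:4318
ENTERPRISE_OTLP_PROTOCOL=http
ENTERPRISE_OTLP_HEADERS=x-scope-orgid=tenant1
ENTERPRISE_INCLUDE_CONTENT=false
ENTERPRISE_OTEL_SAMPLING_RATE=1.0
```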
## Correlation Model
Dify uses deterministic ID generation to ensure signals are correlated across different services and asynchronous tasks.
### ID Generation Rules
- `trace_id`: Derived from the correlation ID (workflow_run_id or node_execution_id for drafts) using `int(UUID(correlation_id))`
- `span_id`: Derived from the source ID using the lower 64 bits of `UUID(source_id)`
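A minimal Python sketch of these derivations (helper names are illustrative, not the exporter's actual API):
```python
import uuid


def derive_trace_id(correlation_id: str) -> int:
    # The full 128-bit integer of the correlation UUID becomes the OTEL trace_id.
    return uuid.UUID(correlation_id).int


def derive_span_id(source_id: str) -> int:
    # The lower 64 bits of the source UUID become the OTEL span_id.
    return uuid.UUID(source_id).int & ((1 << 64) - 1)
```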
### Scenario A: Simple Workflow
A single workflow run with multiple nodes. All spans and logs share the same `trace_id` (derived from `workflow_run_id`).
```
trace_id = UUID(workflow_run_id)
├── [root span] dify.workflow.run (span_id = hash(workflow_run_id))
│ ├── [child] dify.node.execution - "Start" (span_id = hash(node_exec_id_1))
│ ├── [child] dify.node.execution - "LLM" (span_id = hash(node_exec_id_2))
│ └── [child] dify.node.execution - "End" (span_id = hash(node_exec_id_3))
```
### Scenario B: Nested Sub-Workflow
A workflow calling another workflow via a Tool or Sub-workflow node. The child workflow's spans are linked to the parent via `parent_span_id`. Both workflows share the same `trace_id`.
```
trace_id = UUID(outer_workflow_run_id) ← shared across both workflows
├── [root] dify.workflow.run (outer) (span_id = hash(outer_workflow_run_id))
│ ├── dify.node.execution - "Start Node"
│ ├── dify.node.execution - "Tool Node" (triggers sub-workflow)
│ │ └── [child] dify.workflow.run (inner) (span_id = hash(inner_workflow_run_id))
│ │ ├── dify.node.execution - "Inner Start"
│ │ └── dify.node.execution - "Inner End"
│ └── dify.node.execution - "End Node"
```
**Key attributes for nested workflows:**
- Inner workflow's `dify.parent.trace_id` = outer `workflow_run_id`
- Inner workflow's `dify.parent.node.execution_id` = tool node's `execution_id`
- Inner workflow's `dify.parent.workflow.run_id` = outer `workflow_run_id`
- Inner workflow's `dify.parent.app.id` = outer `app_id`
### Scenario C: Draft Node Execution
A single node run in isolation (debugger/preview mode). It creates its own trace where the node span is the root.
```
trace_id = UUID(node_execution_id) ← own trace, NOT part of any workflow
└── dify.node.execution.draft (span_id = hash(node_execution_id))
```
**Key difference:** Draft executions use `node_execution_id` as the correlation ID, so they are NOT children of any workflow trace.
## Content Gating
When `ENTERPRISE_INCLUDE_CONTENT` is set to `false`, sensitive content attributes (inputs, outputs, queries) are replaced with reference strings (e.g., `ref:workflow_run_id=...`) to prevent data leakage to the OTEL collector.
**Reference String Format:**
```
ref:{id_type}={uuid}
```
**Examples:**
```
ref:workflow_run_id=550e8400-e29b-41d4-a716-446655440000
ref:node_execution_id=660e8400-e29b-41d4-a716-446655440001
ref:message_id=770e8400-e29b-41d4-a716-446655440002
```
To retrieve actual content when gating is enabled, query the Dify database using the provided UUID.
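A small helper sketch for splitting a reference string back into its ID type and UUID before looking up the record (illustrative, not part of the exporter):
```python
def parse_content_ref(ref: str) -> tuple[str, str]:
    """Split 'ref:workflow_run_id=550e8400-...' into ('workflow_run_id', '550e8400-...')."""
    if not ref.startswith("ref:"):
        raise ValueError(f"not a content reference: {ref!r}")
    id_type, _, uuid_value = ref[len("ref:"):].partition("=")
    return id_type, uuid_value
```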
## Reference
For a complete list of telemetry signals, attributes, and data structures, see [DATA_DICTIONARY.md](./DATA_DICTIONARY.md).

View File

@@ -0,0 +1,73 @@
"""Telemetry gateway contracts and data structures.
This module defines the envelope format for telemetry events and the routing
configuration that determines how each event type is processed.
"""
from __future__ import annotations
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, ConfigDict
class TelemetryCase(StrEnum):
"""Enumeration of all known telemetry event cases."""
WORKFLOW_RUN = "workflow_run"
NODE_EXECUTION = "node_execution"
DRAFT_NODE_EXECUTION = "draft_node_execution"
MESSAGE_RUN = "message_run"
TOOL_EXECUTION = "tool_execution"
MODERATION_CHECK = "moderation_check"
SUGGESTED_QUESTION = "suggested_question"
DATASET_RETRIEVAL = "dataset_retrieval"
GENERATE_NAME = "generate_name"
PROMPT_GENERATION = "prompt_generation"
APP_CREATED = "app_created"
APP_UPDATED = "app_updated"
APP_DELETED = "app_deleted"
FEEDBACK_CREATED = "feedback_created"
class SignalType(StrEnum):
"""Signal routing type for telemetry cases."""
TRACE = "trace"
METRIC_LOG = "metric_log"
class CaseRoute(BaseModel):
"""Routing configuration for a telemetry case.
Attributes:
signal_type: The type of signal (trace or metric_log).
ce_eligible: Whether this case is eligible for community edition tracing.
"""
signal_type: SignalType
ce_eligible: bool
class TelemetryEnvelope(BaseModel):
"""Envelope for telemetry events.
Attributes:
case: The telemetry case type.
tenant_id: The tenant identifier.
event_id: Unique event identifier for deduplication.
payload: The main event payload (inline for small payloads,
empty when offloaded to storage via ``payload_ref``).
metadata: Optional metadata dictionary. When the gateway
offloads a large payload to object storage, this contains
``{"payload_ref": "<storage_key>"}``.
"""
model_config = ConfigDict(extra="forbid", use_enum_values=False)
case: TelemetryCase
tenant_id: str
event_id: str
payload: dict[str, Any]
metadata: dict[str, Any] | None = None

View File

@@ -0,0 +1,89 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from graphon.enums import WorkflowNodeExecutionMetadataKey
from models.workflow import WorkflowNodeExecutionModel
def enqueue_draft_node_execution_trace(
*,
execution: WorkflowNodeExecutionModel,
outputs: Mapping[str, Any] | None,
workflow_execution_id: str | None,
user_id: str,
) -> None:
node_data = _build_node_execution_data(
execution=execution,
outputs=outputs,
workflow_execution_id=workflow_execution_id,
)
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id=execution.tenant_id,
user_id=user_id,
app_id=execution.app_id,
),
payload={"node_execution_data": node_data},
)
)
def _build_node_execution_data(
*,
execution: WorkflowNodeExecutionModel,
outputs: Mapping[str, Any] | None,
workflow_execution_id: str | None,
) -> dict[str, Any]:
metadata = execution.execution_metadata_dict
node_outputs = outputs if outputs is not None else execution.outputs_dict
execution_id = workflow_execution_id or execution.workflow_run_id or execution.id
process_data = execution.process_data_dict or {}
# Extract token breakdown from outputs.usage (set by LLM node)
usage: Mapping[str, Any] = {}
if isinstance(node_outputs, Mapping):
raw_usage = node_outputs.get("usage")
if isinstance(raw_usage, Mapping):
usage = raw_usage
return {
"workflow_id": execution.workflow_id,
"workflow_execution_id": execution_id,
"tenant_id": execution.tenant_id,
"app_id": execution.app_id,
"node_execution_id": execution.id,
"node_id": execution.node_id,
"node_type": execution.node_type,
"title": execution.title,
"status": execution.status,
"error": execution.error,
"elapsed_time": execution.elapsed_time,
"index": execution.index,
"predecessor_node_id": execution.predecessor_node_id,
"created_at": execution.created_at,
"finished_at": execution.finished_at,
"total_tokens": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0),
"total_price": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0),
"currency": metadata.get(WorkflowNodeExecutionMetadataKey.CURRENCY),
"model_provider": process_data.get("model_provider"),
"model_name": process_data.get("model_name"),
"prompt_tokens": usage.get("prompt_tokens"),
"completion_tokens": usage.get("completion_tokens"),
"tool_name": (metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name")
if isinstance(metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict)
else None,
"iteration_id": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID),
"iteration_index": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX),
"loop_id": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID),
"loop_index": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX),
"parallel_id": metadata.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID),
"node_inputs": execution.inputs_dict,
"node_outputs": node_outputs,
"process_data": execution.process_data_dict,
}

View File

@@ -0,0 +1,966 @@
"""Enterprise trace handler — duck-typed, NOT a BaseTraceInstance subclass.
Invoked directly in the Celery task, not through OpsTraceManager dispatch.
Only requires a matching ``trace(trace_info)`` method signature.
Signal strategy:
- **Traces (spans)**: workflow run, node execution, draft node execution only.
- **Metrics + structured logs**: all other event types.
Token metric labels (unified structure):
All token metrics (dify.tokens.input, dify.tokens.output, dify.tokens.total) use the
same label set for consistent filtering and aggregation:
- tenant_id: Tenant identifier
- app_id: Application identifier
- operation_type: Source of token usage (workflow | node_execution | message | rule_generate | etc.)
- model_provider: LLM provider name (empty string if not applicable)
- model_name: LLM model name (empty string if not applicable)
- node_type: Workflow node type (empty string if not node_execution)
This unified structure allows filtering by operation_type to separate:
- Workflow-level aggregates (operation_type=workflow)
- Individual node executions (operation_type=node_execution)
- Direct message calls (operation_type=message)
- Prompt generation operations (operation_type=rule_generate, code_generate, etc.)
Without this, tokens are double-counted when querying totals (workflow totals include
node totals, since workflow.total_tokens is the sum of all node tokens).
"""
from __future__ import annotations
import json
import logging
from typing import Any, cast
from opentelemetry.util.types import AttributeValue
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
DraftNodeExecutionTrace,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
OperationType,
PromptGenerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowNodeTraceInfo,
WorkflowTraceInfo,
)
from enterprise.telemetry.entities import (
EnterpriseTelemetryCounter,
EnterpriseTelemetryEvent,
EnterpriseTelemetryHistogram,
EnterpriseTelemetrySpan,
TokenMetricLabels,
)
from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log
logger = logging.getLogger(__name__)
class EnterpriseOtelTrace:
"""Duck-typed enterprise trace handler.
``*_trace`` methods emit spans (workflow/node only) or structured logs
(all other events), plus metrics at 100% accuracy.
"""
def __init__(self) -> None:
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if exporter is None:
raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized")
self._exporter = exporter
def trace(self, trace_info: BaseTraceInfo) -> None:
if isinstance(trace_info, WorkflowTraceInfo):
self._workflow_trace(trace_info)
elif isinstance(trace_info, MessageTraceInfo):
self._message_trace(trace_info)
elif isinstance(trace_info, ToolTraceInfo):
self._tool_trace(trace_info)
elif isinstance(trace_info, DraftNodeExecutionTrace):
self._draft_node_execution_trace(trace_info)
elif isinstance(trace_info, WorkflowNodeTraceInfo):
self._node_execution_trace(trace_info)
elif isinstance(trace_info, ModerationTraceInfo):
self._moderation_trace(trace_info)
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
self._suggested_question_trace(trace_info)
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
self._dataset_retrieval_trace(trace_info)
elif isinstance(trace_info, GenerateNameTraceInfo):
self._generate_name_trace(trace_info)
elif isinstance(trace_info, PromptGenerationTraceInfo):
self._prompt_generation_trace(trace_info)
else:
raise AssertionError("this statment should be unreachable")
def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
metadata = self._metadata(trace_info)
tenant_id, app_id, user_id = self._context_ids(trace_info, metadata)
return {
"dify.trace_id": trace_info.resolved_trace_id,
"dify.tenant_id": tenant_id,
"dify.app_id": app_id,
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"gen_ai.user.id": user_id,
"dify.message.id": trace_info.message_id,
}
def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
return trace_info.metadata
def _context_ids(
self,
trace_info: BaseTraceInfo,
metadata: dict[str, Any],
) -> tuple[str | None, str | None, str | None]:
tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id")
app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id")
user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id")
return tenant_id, app_id, user_id
def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]:
return dict(values)
def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None:
if isinstance(value, str):
return value
if isinstance(value, dict):
return cast(dict[str, Any], value)
if isinstance(value, list):
items: list[object] = []
for item in cast(list[object], value):
items.append(item)
return items
return None
def _content_or_ref(self, value: Any, ref: str) -> Any:
if self._exporter.include_content:
return self._maybe_json(value)
return ref
def _maybe_json(self, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
try:
return json.dumps(value, default=str)
except (TypeError, ValueError):
return str(value)
# ------------------------------------------------------------------
# SPAN-emitting handlers (workflow, node execution, draft node)
# ------------------------------------------------------------------
def _workflow_trace(self, info: WorkflowTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
# -- Span attrs: identity + structure + status + timing + gen_ai scalars --
span_attrs: dict[str, Any] = {
"dify.trace_id": info.resolved_trace_id,
"dify.tenant_id": tenant_id,
"dify.app_id": app_id,
"dify.workflow.id": info.workflow_id,
"dify.workflow.run_id": info.workflow_run_id,
"dify.workflow.status": info.workflow_run_status,
"dify.workflow.error": info.error,
"dify.workflow.elapsed_time": info.workflow_run_elapsed_time,
"dify.invoke_from": metadata.get("triggered_from"),
"dify.conversation.id": info.conversation_id,
"dify.message.id": info.message_id,
"dify.invoked_by": info.invoked_by,
"gen_ai.usage.total_tokens": info.total_tokens,
"gen_ai.user.id": user_id,
}
trace_correlation_override, parent_span_id_source = info.resolved_parent_context
parent_ctx = metadata.get("parent_trace_context")
if isinstance(parent_ctx, dict):
parent_ctx_dict = cast(dict[str, Any], parent_ctx)
span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id")
span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id")
span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id")
span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id")
self._exporter.export_span(
EnterpriseTelemetrySpan.WORKFLOW_RUN,
span_attrs,
correlation_id=info.workflow_run_id,
span_id_source=info.workflow_run_id,
start_time=info.start_time,
end_time=info.end_time,
trace_correlation_override=trace_correlation_override,
parent_span_id_source=parent_span_id_source,
)
# -- Companion log: ALL attrs (span + detail) for full picture --
log_attrs: dict[str, Any] = {**span_attrs}
log_attrs.update(
{
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"gen_ai.user.id": user_id,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.workflow.version": info.workflow_run_version,
}
)
ref = f"ref:workflow_run_id={info.workflow_run_id}"
log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref)
log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref)
log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref)
emit_telemetry_log(
event_name=EnterpriseTelemetryEvent.WORKFLOW_RUN,
attributes=log_attrs,
signal="span_detail",
trace_id_source=info.workflow_run_id,
span_id_source=info.workflow_run_id,
tenant_id=tenant_id,
user_id=user_id,
)
# -- Metrics --
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
token_labels = TokenMetricLabels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=OperationType.WORKFLOW,
model_provider="",
model_name="",
node_type="",
).to_dict()
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
if info.prompt_tokens is not None and info.prompt_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels)
if info.completion_tokens is not None and info.completion_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
)
invoke_from = metadata.get("triggered_from", "")
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="workflow",
status=info.workflow_run_status,
invoke_from=invoke_from,
),
)
# Prefer wall-clock timestamps over the elapsed_time field: elapsed_time defaults
# to 0 in the DB and can be stale if the Celery write races with the trace task.
# start_time = workflow_run.created_at, end_time = workflow_run.finished_at.
if info.start_time and info.end_time:
workflow_duration = (info.end_time - info.start_time).total_seconds()
elif info.workflow_run_elapsed_time:
workflow_duration = float(info.workflow_run_elapsed_time)
else:
workflow_duration = 0.0
self._exporter.record_histogram(
EnterpriseTelemetryHistogram.WORKFLOW_DURATION,
workflow_duration,
self._labels(
**labels,
status=info.workflow_run_status,
),
)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="workflow",
),
)
def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None:
self._emit_node_execution_trace(info, EnterpriseTelemetrySpan.NODE_EXECUTION, "node")
def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None:
self._emit_node_execution_trace(
info,
EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION,
"draft_node",
correlation_id_override=info.node_execution_id,
trace_correlation_override_param=info.workflow_run_id,
)
def _emit_node_execution_trace(
self,
info: WorkflowNodeTraceInfo,
span_name: EnterpriseTelemetrySpan,
request_type: str,
correlation_id_override: str | None = None,
trace_correlation_override_param: str | None = None,
) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
# -- Span attrs: identity + structure + status + timing + gen_ai scalars --
span_attrs: dict[str, Any] = {
"dify.trace_id": info.resolved_trace_id,
"dify.tenant_id": tenant_id,
"dify.app_id": app_id,
"dify.workflow.id": info.workflow_id,
"dify.workflow.run_id": info.workflow_run_id,
"dify.message.id": info.message_id,
"dify.conversation.id": metadata.get("conversation_id"),
"dify.node.execution_id": info.node_execution_id,
"dify.node.id": info.node_id,
"dify.node.type": info.node_type,
"dify.node.title": info.title,
"dify.node.status": info.status,
"dify.node.error": info.error,
"dify.node.elapsed_time": info.elapsed_time,
"dify.node.index": info.index,
"dify.node.predecessor_node_id": info.predecessor_node_id,
"dify.node.iteration_id": info.iteration_id,
"dify.node.loop_id": info.loop_id,
"dify.node.parallel_id": info.parallel_id,
"dify.node.invoked_by": info.invoked_by,
"gen_ai.usage.input_tokens": info.prompt_tokens,
"gen_ai.usage.output_tokens": info.completion_tokens,
"gen_ai.usage.total_tokens": info.total_tokens,
"gen_ai.request.model": info.model_name,
"gen_ai.provider.name": info.model_provider,
"gen_ai.user.id": user_id,
}
resolved_override, _ = info.resolved_parent_context
trace_correlation_override = trace_correlation_override_param or resolved_override
effective_correlation_id = correlation_id_override or info.workflow_run_id
self._exporter.export_span(
span_name,
span_attrs,
correlation_id=effective_correlation_id,
span_id_source=info.node_execution_id,
start_time=info.start_time,
end_time=info.end_time,
trace_correlation_override=trace_correlation_override,
)
# -- Companion log: ALL attrs (span + detail) --
log_attrs: dict[str, Any] = {**span_attrs}
log_attrs.update(
{
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"dify.invoke_from": metadata.get("invoke_from"),
"gen_ai.user.id": user_id,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.node.total_price": info.total_price,
"dify.node.currency": info.currency,
"gen_ai.provider.name": info.model_provider,
"gen_ai.request.model": info.model_name,
"gen_ai.tool.name": info.tool_name,
"dify.node.iteration_index": info.iteration_index,
"dify.node.loop_index": info.loop_index,
"dify.plugin.name": metadata.get("plugin_name"),
"dify.credential.name": metadata.get("credential_name"),
"dify.credential.id": metadata.get("credential_id"),
"dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")),
"dify.dataset.names": self._maybe_json(metadata.get("dataset_names")),
}
)
ref = f"ref:node_execution_id={info.node_execution_id}"
log_attrs["dify.node.inputs"] = self._content_or_ref(info.node_inputs, ref)
log_attrs["dify.node.outputs"] = self._content_or_ref(info.node_outputs, ref)
log_attrs["dify.node.process_data"] = self._content_or_ref(info.process_data, ref)
emit_telemetry_log(
event_name=span_name.value,
attributes=log_attrs,
signal="span_detail",
trace_id_source=info.workflow_run_id,
span_id_source=info.node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
# -- Metrics --
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
node_type=info.node_type,
model_provider=info.model_provider or "",
)
if info.total_tokens:
token_labels = TokenMetricLabels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=OperationType.NODE_EXECUTION,
model_provider=info.model_provider or "",
model_name=info.model_name or "",
node_type=info.node_type,
).to_dict()
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
if info.prompt_tokens is not None and info.prompt_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels
)
if info.completion_tokens is not None and info.completion_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type=request_type,
status=info.status,
model_name=info.model_name or "",
),
)
duration_labels = dict(labels)
duration_labels["model_name"] = info.model_name or ""
plugin_name = metadata.get("plugin_name")
if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}:
duration_labels["plugin_name"] = plugin_name
self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type=request_type,
model_name=info.model_name or "",
),
)
# ------------------------------------------------------------------
# METRIC-ONLY handlers (structured log + counters/histograms)
# ------------------------------------------------------------------
def _message_trace(self, info: MessageTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"dify.invoke_from": metadata.get("from_source"),
"dify.conversation.id": metadata.get("conversation_id"),
"dify.conversation.mode": info.conversation_mode,
"gen_ai.provider.name": metadata.get("ls_provider"),
"gen_ai.request.model": metadata.get("ls_model_name"),
"gen_ai.usage.input_tokens": info.message_tokens,
"gen_ai.usage.output_tokens": info.answer_tokens,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.message.status": metadata.get("status"),
"dify.message.error": info.error,
"dify.message.from_source": metadata.get("from_source"),
"dify.message.from_end_user_id": metadata.get("from_end_user_id"),
"dify.message.from_account_id": metadata.get("from_account_id"),
"dify.streaming": info.is_streaming_request,
"dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token,
"dify.message.streaming_duration": info.llm_streaming_time_to_generate,
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
if info.start_time and info.end_time:
attrs["dify.message.duration"] = (info.end_time - info.start_time).total_seconds()
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
ref = f"ref:message_id={info.message_id}"
inputs = self._safe_payload_value(info.inputs)
outputs = self._safe_payload_value(info.outputs)
attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref)
attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.MESSAGE_RUN,
attributes=attrs,
trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None),
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
model_provider=metadata.get("ls_provider") or "",
model_name=metadata.get("ls_model_name") or "",
)
token_labels = TokenMetricLabels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=OperationType.MESSAGE,
model_provider=metadata.get("ls_provider") or "",
model_name=metadata.get("ls_model_name") or "",
node_type="",
).to_dict()
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
if info.message_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.message_tokens, token_labels)
if info.answer_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.answer_tokens, token_labels)
invoke_from = metadata.get("from_source", "")
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="message",
status=metadata.get("status", ""),
invoke_from=invoke_from,
),
)
if info.start_time and info.end_time:
duration = (info.end_time - info.start_time).total_seconds()
self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels)
if info.gen_ai_server_time_to_first_token is not None:
self._exporter.record_histogram(
EnterpriseTelemetryHistogram.MESSAGE_TTFT, info.gen_ai_server_time_to_first_token, labels
)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="message",
),
)
def _tool_trace(self, info: ToolTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"dify.tool.name": info.tool_name,
"dify.tool.duration": float(info.time_cost),
"dify.tool.status": "failed" if info.error else "succeeded",
"dify.tool.error": info.error,
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
ref = f"ref:message_id={info.message_id}"
attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, ref)
attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, ref)
attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, ref)
attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.TOOL_EXECUTION,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
tool_name=info.tool_name,
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="tool",
),
)
self._exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), labels)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="tool",
),
)
def _moderation_trace(self, info: ModerationTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"dify.moderation.flagged": info.flagged,
"dify.moderation.action": info.action,
"dify.moderation.preset_response": info.preset_response,
"dify.moderation.type": metadata.get("moderation_type", "input"),
"dify.moderation.categories": self._maybe_json(metadata.get("moderation_categories", [])),
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
attrs["dify.moderation.query"] = self._content_or_ref(
info.query,
f"ref:message_id={info.message_id}",
)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.MODERATION_CHECK,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="moderation",
),
)
def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
duration: float | None = None
if info.start_time is not None and info.end_time is not None:
duration = (info.end_time - info.start_time).total_seconds()
error = info.error or (info.metadata.get("error") if info.metadata else None)
status = "failed" if error else (info.status or "succeeded")
attrs.update(
{
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.suggested_question.status": status,
"dify.suggested_question.error": error,
"dify.suggested_question.duration": duration,
"gen_ai.provider.name": info.model_provider,
"gen_ai.request.model": info.model_id,
"dify.suggested_question.count": len(info.suggested_question),
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
attrs["dify.suggested_question.questions"] = self._content_or_ref(
info.suggested_question,
f"ref:message_id={info.message_id}",
)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="suggested_question",
model_provider=info.model_provider or "",
model_name=info.model_id or "",
),
)
def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs["dify.retrieval.error"] = info.error
attrs["dify.retrieval.status"] = "failed" if info.error else "succeeded"
if info.start_time and info.end_time:
attrs["dify.retrieval.duration"] = (info.end_time - info.start_time).total_seconds()
attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id")
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
docs: list[dict[str, Any]] = []
documents_any: Any = info.documents
documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else []
for entry in documents_list:
if isinstance(entry, dict):
entry_dict: dict[str, Any] = cast(dict[str, Any], entry)
docs.append(entry_dict)
dataset_ids: list[str] = []
dataset_names: list[str] = []
structured_docs: list[dict[str, Any]] = []
for doc in docs:
meta_raw = doc.get("metadata")
meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {}
did = meta.get("dataset_id")
dname = meta.get("dataset_name")
if did and did not in dataset_ids:
dataset_ids.append(did)
if dname and dname not in dataset_names:
dataset_names.append(dname)
structured_docs.append(
{
"dataset_id": did,
"document_id": meta.get("document_id"),
"segment_id": meta.get("segment_id"),
"score": meta.get("score"),
}
)
attrs["dify.dataset.id"] = self._maybe_json(dataset_ids)
attrs["dify.dataset.name"] = self._maybe_json(dataset_names)
attrs["dify.retrieval.document_count"] = len(docs)
embedding_models_raw: Any = metadata.get("embedding_models")
embedding_models: dict[str, Any] = (
cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {}
)
if embedding_models:
providers: list[str] = []
models: list[str] = []
for ds_info in embedding_models.values():
if isinstance(ds_info, dict):
ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info)
p = ds_info_dict.get("embedding_model_provider", "")
m = ds_info_dict.get("embedding_model", "")
if p and p not in providers:
providers.append(p)
if m and m not in models:
models.append(m)
attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers)
attrs["dify.dataset.embedding_models"] = self._maybe_json(models)
# Add rerank model to logs
rerank_provider = metadata.get("rerank_model_provider", "")
rerank_model = metadata.get("rerank_model_name", "")
if rerank_provider or rerank_model:
attrs["dify.retrieval.rerank_provider"] = rerank_provider
attrs["dify.retrieval.rerank_model"] = rerank_model
ref = f"ref:message_id={info.message_id}"
retrieval_inputs = self._safe_payload_value(info.inputs)
attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref)
attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.DATASET_RETRIEVAL,
attributes=attrs,
trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None),
span_id_source=node_execution_id or (str(info.message_id) if info.message_id else None),
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="dataset_retrieval",
),
)
for did in dataset_ids:
# Get embedding model for this specific dataset
ds_embedding_info = embedding_models.get(did, {})
embedding_provider = ds_embedding_info.get("embedding_model_provider", "")
embedding_model = ds_embedding_info.get("embedding_model", "")
# Get rerank model (same for all datasets in this retrieval)
rerank_provider = metadata.get("rerank_model_provider", "")
rerank_model = metadata.get("rerank_model_name", "")
self._exporter.increment_counter(
EnterpriseTelemetryCounter.DATASET_RETRIEVALS,
1,
self._labels(
**labels,
dataset_id=did,
embedding_model_provider=embedding_provider,
embedding_model=embedding_model,
rerank_model_provider=rerank_provider,
rerank_model=rerank_model,
),
)
def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs["dify.conversation.id"] = info.conversation_id
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
duration: float | None = None
if info.start_time is not None and info.end_time is not None:
duration = (info.end_time - info.start_time).total_seconds()
error: str | None = metadata.get("error") if metadata else None
status = "failed" if error else "succeeded"
attrs["dify.generate_name.duration"] = duration
attrs["dify.generate_name.status"] = status
attrs["dify.generate_name.error"] = error
ref = f"ref:conversation_id={info.conversation_id}"
inputs = self._safe_payload_value(info.inputs)
outputs = self._safe_payload_value(info.outputs)
attrs["dify.generate_name.inputs"] = self._content_or_ref(inputs, ref)
attrs["dify.generate_name.outputs"] = self._content_or_ref(outputs, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="generate_name",
),
)
def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = {
"dify.trace_id": info.resolved_trace_id,
"dify.tenant_id": tenant_id,
"gen_ai.user.id": user_id,
"dify.app_id": app_id or "",
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"dify.prompt_generation.operation_type": info.operation_type,
"gen_ai.provider.name": info.model_provider,
"gen_ai.request.model": info.model_name,
"gen_ai.usage.input_tokens": info.prompt_tokens,
"gen_ai.usage.output_tokens": info.completion_tokens,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.prompt_generation.duration": info.latency,
"dify.prompt_generation.status": "failed" if info.error else "succeeded",
"dify.prompt_generation.error": info.error,
}
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
if info.total_price is not None:
attrs["dify.prompt_generation.total_price"] = info.total_price
attrs["dify.prompt_generation.currency"] = info.currency
ref = f"ref:trace_id={info.trace_id}"
outputs = self._safe_payload_value(info.outputs)
attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, ref)
attrs["dify.prompt_generation.output"] = self._content_or_ref(outputs, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
token_labels = TokenMetricLabels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=info.operation_type,
model_provider=info.model_provider,
model_name=info.model_name,
node_type="",
).to_dict()
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=info.operation_type,
model_provider=info.model_provider,
model_name=info.model_name,
)
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
if info.prompt_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels)
if info.completion_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
)
prompt_status = "failed" if info.error else "succeeded"
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="prompt_generation",
status=prompt_status,
),
)
self._exporter.record_histogram(
EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION,
info.latency,
labels,
)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="prompt_generation",
),
)
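The handlers above repeatedly call a private _content_or_ref helper that is defined elsewhere in this class and not shown in this excerpt. As a rough mental model only (an assumption about its behaviour, not the actual implementation), it keeps the payload when content export is allowed and otherwise emits just the "ref:..." pointer built above, so the structured logs never leak content when content capture is disabled:

# Hedged sketch of the assumed _content_or_ref behaviour (illustrative only).
import json
from typing import Any

def content_or_ref_sketch(value: Any, ref: str, include_content: bool) -> str:
    if value is None:
        return ref
    if include_content:
        return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False, default=str)
    return ref

print(content_or_ref_sketch({"query": "hi"}, "ref:node_execution_id=abc", include_content=False))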

View File

@@ -0,0 +1,121 @@
from enum import StrEnum
from typing import cast
from opentelemetry.util.types import AttributeValue
from pydantic import BaseModel, ConfigDict
class EnterpriseTelemetrySpan(StrEnum):
WORKFLOW_RUN = "dify.workflow.run"
NODE_EXECUTION = "dify.node.execution"
DRAFT_NODE_EXECUTION = "dify.node.execution.draft"
class EnterpriseTelemetryEvent(StrEnum):
"""Event names for enterprise telemetry logs."""
APP_CREATED = "dify.app.created"
APP_UPDATED = "dify.app.updated"
APP_DELETED = "dify.app.deleted"
FEEDBACK_CREATED = "dify.feedback.created"
WORKFLOW_RUN = "dify.workflow.run"
MESSAGE_RUN = "dify.message.run"
TOOL_EXECUTION = "dify.tool.execution"
MODERATION_CHECK = "dify.moderation.check"
SUGGESTED_QUESTION_GENERATION = "dify.suggested_question.generation"
DATASET_RETRIEVAL = "dify.dataset.retrieval"
GENERATE_NAME_EXECUTION = "dify.generate_name.execution"
PROMPT_GENERATION_EXECUTION = "dify.prompt_generation.execution"
REHYDRATION_FAILED = "dify.telemetry.rehydration_failed"
class EnterpriseTelemetryCounter(StrEnum):
TOKENS = "tokens"
INPUT_TOKENS = "input_tokens"
OUTPUT_TOKENS = "output_tokens"
REQUESTS = "requests"
ERRORS = "errors"
FEEDBACK = "feedback"
DATASET_RETRIEVALS = "dataset_retrievals"
APP_CREATED = "app_created"
APP_UPDATED = "app_updated"
APP_DELETED = "app_deleted"
class EnterpriseTelemetryHistogram(StrEnum):
WORKFLOW_DURATION = "workflow_duration"
NODE_DURATION = "node_duration"
MESSAGE_DURATION = "message_duration"
MESSAGE_TTFT = "message_ttft"
TOOL_DURATION = "tool_duration"
PROMPT_GENERATION_DURATION = "prompt_generation_duration"
class TokenMetricLabels(BaseModel):
"""Unified label structure for all dify.token.* metrics.
All token counters (dify.tokens.input, dify.tokens.output, dify.tokens.total) MUST
use this exact label set to ensure consistent filtering and aggregation across
different operation types.
Attributes:
tenant_id: Tenant identifier.
app_id: Application identifier.
operation_type: Source of token usage (workflow | node_execution | message |
rule_generate | code_generate | structured_output | instruction_modify).
model_provider: LLM provider name. Empty string if not applicable (e.g., workflow-level).
model_name: LLM model name. Empty string if not applicable (e.g., workflow-level).
node_type: Workflow node type. Empty string unless operation_type=node_execution.
Usage:
labels = TokenMetricLabels(
tenant_id="tenant-123",
app_id="app-456",
operation_type=OperationType.WORKFLOW,
model_provider="",
model_name="",
node_type="",
)
exporter.increment_counter(
EnterpriseTelemetryCounter.INPUT_TOKENS,
100,
labels.to_dict()
)
Design rationale:
Without this unified structure, tokens get double-counted when querying totals
because workflow.total_tokens is already the sum of all node tokens. The
operation_type label allows filtering to separate workflow-level aggregates from
node-level detail, while keeping the same label cardinality for consistent queries.
"""
tenant_id: str
app_id: str
operation_type: str
model_provider: str
model_name: str
node_type: str
model_config = ConfigDict(extra="forbid", frozen=True)
def to_dict(self) -> dict[str, AttributeValue]:
return cast(
dict[str, AttributeValue],
{
"tenant_id": self.tenant_id,
"app_id": self.app_id,
"operation_type": self.operation_type,
"model_provider": self.model_provider,
"model_name": self.model_name,
"node_type": self.node_type,
},
)
__all__ = [
"EnterpriseTelemetryCounter",
"EnterpriseTelemetryEvent",
"EnterpriseTelemetryHistogram",
"EnterpriseTelemetrySpan",
"TokenMetricLabels",
]
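Because the model is frozen and forbids extra fields, any drift in the token-metric label contract fails loudly at construction time instead of silently creating a new metric dimension. A minimal sketch with hypothetical IDs:

# Sketch: the frozen, extra="forbid" config turns label drift into an error.
from pydantic import ValidationError

ok = TokenMetricLabels(
    tenant_id="t-1",
    app_id="a-1",
    operation_type="message",
    model_provider="openai",
    model_name="gpt-4o",
    node_type="",
)
print(ok.to_dict())

try:
    TokenMetricLabels(
        tenant_id="t-1",
        app_id="a-1",
        operation_type="message",
        model_provider="openai",
        model_name="gpt-4o",
        node_type="",
        region="eu-west",  # not part of the unified label set
    )
except ValidationError:
    print("rejected extra label")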

View File

@@ -0,0 +1,72 @@
"""Blinker signal handlers for enterprise telemetry.
Registered at import time via ``@signal.connect`` decorators.
Import must happen during ``ext_enterprise_telemetry.init_app()`` to
ensure handlers fire. Each handler delegates to ``core.telemetry.gateway``
which handles routing, EE-gating, and dispatch.
All handlers are best-effort: exceptions are caught and logged so that
telemetry failures never break user-facing operations.
"""
from __future__ import annotations
import logging
from events.app_event import app_was_created, app_was_deleted, app_was_updated
logger = logging.getLogger(__name__)
__all__ = [
"_handle_app_created",
"_handle_app_deleted",
"_handle_app_updated",
]
@app_was_created.connect
def _handle_app_created(sender: object, **kwargs: object) -> None:
try:
from core.telemetry.gateway import emit as gateway_emit
from enterprise.telemetry.contracts import TelemetryCase
gateway_emit(
case=TelemetryCase.APP_CREATED,
context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")},
payload={
"app_id": getattr(sender, "id", None),
"mode": getattr(sender, "mode", None),
},
)
except Exception:
logger.warning("Failed to emit app_created telemetry", exc_info=True)
@app_was_updated.connect
def _handle_app_updated(sender: object, **kwargs: object) -> None:
try:
from core.telemetry.gateway import emit as gateway_emit
from enterprise.telemetry.contracts import TelemetryCase
gateway_emit(
case=TelemetryCase.APP_UPDATED,
context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")},
payload={"app_id": getattr(sender, "id", None)},
)
except Exception:
logger.warning("Failed to emit app_updated telemetry", exc_info=True)
@app_was_deleted.connect
def _handle_app_deleted(sender: object, **kwargs: object) -> None:
try:
from core.telemetry.gateway import emit as gateway_emit
from enterprise.telemetry.contracts import TelemetryCase
gateway_emit(
case=TelemetryCase.APP_DELETED,
context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")},
payload={"app_id": getattr(sender, "id", None)},
)
except Exception:
logger.warning("Failed to emit app_deleted telemetry", exc_info=True)

View File

@@ -0,0 +1,283 @@
"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation.
Uses dedicated TracerProvider and MeterProvider instances (configurable sampling,
independent from ext_otel.py infrastructure).
Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py).
Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process.
"""
import logging
import socket
import uuid
from datetime import UTC, datetime
from typing import Any, cast
from opentelemetry import trace
from opentelemetry.baggage import get_all
from opentelemetry.baggage.propagation import W3CBaggagePropagator
from opentelemetry.context import Context
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import SpanContext, TraceFlags
from opentelemetry.util.types import Attributes, AttributeValue
from configs import dify_config
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram
from enterprise.telemetry.id_generator import (
CorrelationIdGenerator,
compute_deterministic_span_id,
set_correlation_id,
set_span_id_source,
)
logger = logging.getLogger(__name__)
def is_enterprise_telemetry_enabled() -> bool:
return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED)
def _parse_otlp_headers(raw: str) -> dict[str, str]:
ctx = W3CBaggagePropagator().extract({"baggage": raw})
return {k: v for k, v in get_all(ctx).items() if isinstance(v, str)}
def _datetime_to_ns(dt: datetime) -> int:
"""Convert a datetime to nanoseconds since epoch (OTEL convention)."""
# Ensure we always interpret naive datetimes as UTC instead of local time.
if dt.tzinfo is None:
dt = dt.replace(tzinfo=UTC)
else:
dt = dt.astimezone(UTC)
return int(dt.timestamp() * 1_000_000_000)
class _ExporterFactory:
def __init__(self, protocol: str, endpoint: str, headers: dict[str, str], insecure: bool):
self._protocol = protocol
self._endpoint = endpoint
self._headers = headers
self._grpc_headers = tuple(headers.items()) if headers else None
self._http_headers = headers or None
self._insecure = insecure
def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter:
if self._protocol == "grpc":
return GRPCSpanExporter(
endpoint=self._endpoint or None,
headers=self._grpc_headers,
insecure=self._insecure,
)
trace_endpoint = f"{self._endpoint}/v1/traces" if self._endpoint else ""
return HTTPSpanExporter(endpoint=trace_endpoint or None, headers=self._http_headers)
def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter:
if self._protocol == "grpc":
return GRPCMetricExporter(
endpoint=self._endpoint or None,
headers=self._grpc_headers,
insecure=self._insecure,
)
metric_endpoint = f"{self._endpoint}/v1/metrics" if self._endpoint else ""
return HTTPMetricExporter(endpoint=metric_endpoint or None, headers=self._http_headers)
class EnterpriseExporter:
"""Shared OTEL exporter for all enterprise telemetry.
``export_span`` creates spans with optional real timestamps, deterministic
span/trace IDs, and cross-workflow parent linking.
``increment_counter`` / ``record_histogram`` emit OTEL metrics unconditionally: metrics bypass the span sampler, so counts stay exact regardless of the configured sampling rate.
"""
def __init__(self, config: object) -> None:
endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "")
headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "")
protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower()
service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify")
sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0)
self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True)
api_key: str = getattr(config, "ENTERPRISE_OTLP_API_KEY", "")
# Auto-detect TLS: https:// uses secure, everything else is insecure
insecure = not endpoint.startswith("https://")
resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.HOST_NAME: socket.gethostname(),
}
)
sampler = ParentBasedTraceIdRatio(sampling_rate)
id_generator = CorrelationIdGenerator()
self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator)
headers = _parse_otlp_headers(headers_raw)
if api_key:
if "authorization" in headers:
logger.warning(
"ENTERPRISE_OTLP_API_KEY is set but ENTERPRISE_OTLP_HEADERS also contains "
"'authorization'; the API key will take precedence."
)
headers["authorization"] = f"Bearer {api_key}"
factory = _ExporterFactory(protocol, endpoint, headers, insecure=insecure)
trace_exporter = factory.create_trace_exporter()
self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
self._tracer = self._tracer_provider.get_tracer("dify.enterprise")
metric_exporter = factory.create_metric_exporter()
self._meter_provider = MeterProvider(
resource=resource,
metric_readers=[PeriodicExportingMetricReader(metric_exporter)],
)
meter = self._meter_provider.get_meter("dify.enterprise")
self._counters = {
EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"),
EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"),
EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"),
EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"),
EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"),
EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"),
EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter(
"dify.dataset.retrievals.total", unit="{retrieval}"
),
EnterpriseTelemetryCounter.APP_CREATED: meter.create_counter("dify.app.created.total", unit="{app}"),
EnterpriseTelemetryCounter.APP_UPDATED: meter.create_counter("dify.app.updated.total", unit="{app}"),
EnterpriseTelemetryCounter.APP_DELETED: meter.create_counter("dify.app.deleted.total", unit="{app}"),
}
self._histograms = {
EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"),
EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"),
EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"),
EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram(
"dify.message.time_to_first_token", unit="s"
),
EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"),
EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram(
"dify.prompt_generation.duration", unit="s"
),
}
def export_span(
self,
name: str,
attributes: dict[str, Any],
correlation_id: str | None = None,
span_id_source: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
trace_correlation_override: str | None = None,
parent_span_id_source: str | None = None,
) -> None:
"""Export an OTEL span with optional deterministic IDs and real timestamps.
Args:
name: Span operation name.
attributes: Span attributes dict.
correlation_id: Source for trace_id derivation (groups spans in one trace).
span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id).
start_time: Real span start time. When None, uses current time.
end_time: Real span end time. When None, span ends immediately.
trace_correlation_override: Override trace_id source (for cross-workflow linking).
When set, trace_id is derived from this instead of ``correlation_id``.
parent_span_id_source: Override parent span_id source (for cross-workflow linking).
When set, parent span_id is derived from this value. When None and
``correlation_id`` is set, parent is the workflow root span.
"""
effective_trace_correlation = trace_correlation_override or correlation_id
set_correlation_id(effective_trace_correlation)
set_span_id_source(span_id_source)
try:
parent_context: Context | None = None
# A span is the "root" of its correlation group when span_id_source == correlation_id
# (i.e. a workflow root span). All other spans are children.
if parent_span_id_source:
# Cross-workflow linking: parent is an explicit span (e.g. tool node in outer workflow)
parent_span_id = compute_deterministic_span_id(parent_span_id_source)
try:
parent_trace_id = int(uuid.UUID(effective_trace_correlation)) if effective_trace_correlation else 0
except (ValueError, AttributeError):
logger.warning(
"Invalid trace correlation UUID for cross-workflow link: %s, span=%s",
effective_trace_correlation,
name,
)
parent_trace_id = 0
if parent_trace_id:
parent_span_context = SpanContext(
trace_id=parent_trace_id,
span_id=parent_span_id,
is_remote=True,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))
elif correlation_id and correlation_id != span_id_source:
# Child span: parent is the correlation-group root (workflow root span)
parent_span_id = compute_deterministic_span_id(correlation_id)
try:
parent_trace_id = int(uuid.UUID(effective_trace_correlation or correlation_id))
except (ValueError, AttributeError):
logger.warning(
"Invalid trace correlation UUID for child span link: %s, span=%s",
effective_trace_correlation or correlation_id,
name,
)
parent_trace_id = 0
if parent_trace_id:
parent_span_context = SpanContext(
trace_id=parent_trace_id,
span_id=parent_span_id,
is_remote=True,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))
span_start_time = _datetime_to_ns(start_time) if start_time is not None else None
span_end_on_exit = end_time is None
with self._tracer.start_as_current_span(
name,
context=parent_context,
start_time=span_start_time,
end_on_exit=span_end_on_exit,
) as span:
for key, value in attributes.items():
if value is not None:
span.set_attribute(key, value)
if end_time is not None:
span.end(end_time=_datetime_to_ns(end_time))
except Exception:
logger.exception("Failed to export span %s", name)
finally:
set_correlation_id(None)
set_span_id_source(None)
def increment_counter(
self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue]
) -> None:
counter = self._counters.get(name)
if counter:
counter.add(value, cast(Attributes, labels))
def record_histogram(
self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue]
) -> None:
histogram = self._histograms.get(name)
if histogram:
histogram.record(value, cast(Attributes, labels))
def shutdown(self) -> None:
self._tracer_provider.shutdown()
self._meter_provider.shutdown()
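Taken together with the ID generator below, the calling convention for export_span looks roughly like this (run and node IDs are hypothetical): the workflow root span passes the run ID as both correlation_id and span_id_source, and node spans reuse the run ID as correlation_id so they land in the same trace as children of the root.

# Rough usage sketch of the deterministic-ID contract (IDs are hypothetical).
from extensions.ext_enterprise_telemetry import get_enterprise_exporter

workflow_run_id = "11111111-2222-4333-8444-555555555555"
node_execution_id = "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"

exporter = get_enterprise_exporter()
if exporter:
    # Root span: span_id_source == correlation_id, so it gets no synthetic parent.
    exporter.export_span(
        "dify.workflow.run",
        {"dify.workflow.run_id": workflow_run_id},
        correlation_id=workflow_run_id,
        span_id_source=workflow_run_id,
    )
    # Child span: same correlation_id, own span_id_source -> parented to the root.
    exporter.export_span(
        "dify.node.execution",
        {"dify.node.execution_id": node_execution_id},
        correlation_id=workflow_run_id,
        span_id_source=node_execution_id,
    )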

View File

@@ -0,0 +1,75 @@
"""Custom OTEL ID Generator for correlation-based trace/span ID derivation.
Uses contextvars for thread-safe correlation_id -> trace_id mapping.
When a span_id_source is set, the span_id is derived deterministically
from that value, enabling any span to reference another as parent
without depending on span creation order.
"""
import random
import uuid
from contextvars import ContextVar
from opentelemetry.sdk.trace.id_generator import IdGenerator
_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None)
_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None)
def set_correlation_id(correlation_id: str | None) -> None:
_correlation_id_context.set(correlation_id)
def get_correlation_id() -> str | None:
return _correlation_id_context.get()
def set_span_id_source(source_id: str | None) -> None:
"""Set the source for deterministic span_id generation.
When set, ``generate_span_id()`` derives the span_id from this value
(lower 64 bits of the UUID). Pass the ``workflow_run_id`` for workflow
root spans or ``node_execution_id`` for node spans.
"""
_span_id_source_context.set(source_id)
def compute_deterministic_span_id(source_id: str) -> int:
"""Derive a deterministic span_id from any UUID string.
Uses the lower 64 bits of the UUID, guaranteeing non-zero output
(OTEL requires span_id != 0).
"""
span_id = uuid.UUID(source_id).int & ((1 << 64) - 1)
return span_id if span_id != 0 else 1
class CorrelationIdGenerator(IdGenerator):
"""ID generator that derives trace_id and optionally span_id from context.
- trace_id: always derived from correlation_id (groups all spans in one trace)
- span_id: derived from span_id_source when set (enables deterministic
parent-child linking), otherwise random
"""
def generate_trace_id(self) -> int:
correlation_id = _correlation_id_context.get()
if correlation_id:
try:
return uuid.UUID(correlation_id).int
except (ValueError, AttributeError):
pass
return random.getrandbits(128)
def generate_span_id(self) -> int:
source = _span_id_source_context.get()
if source:
try:
return compute_deterministic_span_id(source)
except (ValueError, AttributeError):
pass
span_id = random.getrandbits(64)
while span_id == 0:
span_id = random.getrandbits(64)
return span_id
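A short check of the derivation (with a hypothetical node execution ID): the span_id is simply the UUID's lower 64 bits, so it is stable across processes and independent of span creation order.

# Sketch: the same UUID always maps to the same 64-bit span_id.
import uuid

node_execution_id = "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"  # hypothetical
sid = compute_deterministic_span_id(node_execution_id)
assert sid == compute_deterministic_span_id(node_execution_id)
assert sid == uuid.UUID(node_execution_id).int & ((1 << 64) - 1)
print(f"{sid:016x}")  # 16-hex form, as used for log/span correlation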

View File

@@ -0,0 +1,421 @@
"""Enterprise metric/log event handler.
This module processes metric and log telemetry events after they've been
dequeued from the enterprise_telemetry Celery queue. It handles case routing,
idempotency checking, and payload rehydration.
"""
from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from typing import Any
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
logger = logging.getLogger(__name__)
class EnterpriseMetricHandler:
"""Handler for enterprise metric and log telemetry events.
Processes envelopes from the enterprise_telemetry queue, routing each
case to the appropriate handler method. Implements idempotency checking
and payload rehydration with fallback.
"""
def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None:
"""Increment a diagnostic counter for operational monitoring.
Args:
counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total').
labels: Optional labels for the counter.
"""
try:
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
return
full_counter_name = f"enterprise_telemetry.handler.{counter_name}"
logger.debug(
"Diagnostic counter: %s, labels=%s",
full_counter_name,
labels or {},
)
except Exception:
logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True)
def handle(self, envelope: TelemetryEnvelope) -> None:
"""Main entry point for processing telemetry envelopes.
Args:
envelope: The telemetry envelope to process.
"""
# Check for duplicate events
if self._is_duplicate(envelope):
logger.debug(
"Skipping duplicate event: tenant_id=%s, event_id=%s",
envelope.tenant_id,
envelope.event_id,
)
self._increment_diagnostic_counter("deduped_total")
return
# Route to appropriate handler based on case
case = envelope.case
if case == TelemetryCase.APP_CREATED:
self._on_app_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
elif case == TelemetryCase.APP_UPDATED:
self._on_app_updated(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
elif case == TelemetryCase.APP_DELETED:
self._on_app_deleted(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
elif case == TelemetryCase.FEEDBACK_CREATED:
self._on_feedback_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
elif case == TelemetryCase.MESSAGE_RUN:
self._on_message_run(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
elif case == TelemetryCase.TOOL_EXECUTION:
self._on_tool_execution(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
elif case == TelemetryCase.MODERATION_CHECK:
self._on_moderation_check(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
elif case == TelemetryCase.SUGGESTED_QUESTION:
self._on_suggested_question(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
elif case == TelemetryCase.DATASET_RETRIEVAL:
self._on_dataset_retrieval(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
elif case == TelemetryCase.GENERATE_NAME:
self._on_generate_name(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
elif case == TelemetryCase.PROMPT_GENERATION:
self._on_prompt_generation(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
else:
logger.warning(
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
case,
envelope.tenant_id,
envelope.event_id,
)
def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool:
"""Check if this event has already been processed.
Uses Redis with TTL for deduplication. Returns True if duplicate,
False if first time seeing this event.
Args:
envelope: The telemetry envelope to check.
Returns:
True if this event_id has been seen before, False otherwise.
"""
dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}"
try:
# Atomic set-if-not-exists with 1h TTL
# Returns True if key was set (first time), None if already exists (duplicate)
was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600)
return was_set is None
except Exception:
# Fail open: if Redis is unavailable, process the event
# (prefer occasional duplicate over lost data)
logger.warning(
"Redis unavailable for deduplication check, processing event anyway: %s",
envelope.event_id,
exc_info=True,
)
return False
def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]:
"""Rehydrate payload from storage reference or inline data.
If the envelope payload is empty and metadata contains a
``payload_ref``, the full payload is loaded from object storage
(where the gateway wrote it as JSON). When both the inline
payload and storage resolution fail, a degraded-event marker
is emitted so the gap is observable.
Args:
envelope: The telemetry envelope containing payload data.
Returns:
The rehydrated payload dictionary, or ``{}`` on total failure.
"""
payload = envelope.payload
# Resolve from object storage when the gateway offloaded a large payload.
if not payload and envelope.metadata:
payload_ref = envelope.metadata.get("payload_ref")
if payload_ref:
try:
payload_bytes = storage.load(payload_ref)
payload = json.loads(payload_bytes.decode("utf-8"))
logger.debug("Loaded payload from storage: key=%s", payload_ref)
except Exception:
logger.warning(
"Failed to load payload from storage: key=%s, event_id=%s",
payload_ref,
envelope.event_id,
exc_info=True,
)
if not payload:
# Storage resolution failed or no data available — emit degraded event.
logger.error(
"Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s",
envelope.event_id,
envelope.tenant_id,
envelope.case,
)
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.REHYDRATION_FAILED,
attributes={
"tenant_id": envelope.tenant_id,
"dify.telemetry.error": f"Payload rehydration failed for event_id={envelope.event_id}",
"dify.telemetry.payload_type": envelope.case,
"dify.telemetry.correlation_id": envelope.event_id,
},
tenant_id=envelope.tenant_id,
)
self._increment_diagnostic_counter("rehydration_failed_total")
return {}
return payload
# Per-case handlers.
# App and feedback cases emit counters and structured logs here; the remaining
# trace-derived cases are intentional no-ops because their telemetry is emitted
# directly at trace time (see the docstrings below).
def _on_app_created(self, envelope: TelemetryEnvelope) -> None:
"""Handle app created event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app_id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app.mode": payload.get("mode"),
"dify.app.created_at": datetime.now(UTC).isoformat(),
}
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.APP_CREATED,
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.APP_CREATED,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
"mode": str(payload.get("mode", "")),
},
)
def _on_app_updated(self, envelope: TelemetryEnvelope) -> None:
"""Handle app updated event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app_id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app.updated_at": datetime.now(UTC).isoformat(),
}
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.APP_UPDATED,
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.APP_UPDATED,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
},
)
def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None:
"""Handle app deleted event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_DELETED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app_id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app.deleted_at": datetime.now(UTC).isoformat(),
}
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.APP_DELETED,
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.APP_DELETED,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
},
)
def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None:
"""Handle feedback created event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
include_content = exporter.include_content
attrs: dict = {
"dify.message.id": payload.get("message_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app_id": payload.get("app_id"),
"dify.conversation.id": payload.get("conversation_id"),
"gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"),
"dify.feedback.rating": payload.get("rating"),
"dify.feedback.from_source": payload.get("from_source"),
"dify.feedback.created_at": datetime.now(UTC).isoformat(),
}
if include_content:
attrs["dify.feedback.content"] = payload.get("content")
user_id = payload.get("from_end_user_id") or payload.get("from_account_id")
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.FEEDBACK_CREATED,
attributes=attrs,
tenant_id=envelope.tenant_id,
user_id=str(user_id or ""),
)
exporter.increment_counter(
EnterpriseTelemetryCounter.FEEDBACK,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
"rating": str(payload.get("rating", "")),
},
)
def _on_message_run(self, envelope: TelemetryEnvelope) -> None:
"""Handle message run event.
Intentionally a no-op: metrics and structured logs for message runs are
emitted directly by EnterpriseOtelTrace._message_trace at trace time,
not through the metric handler queue path.
"""
logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id)
def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None:
"""Handle tool execution event.
Intentionally a no-op: metrics and structured logs for tool executions
are emitted directly by EnterpriseOtelTrace._tool_trace at trace time,
not through the metric handler queue path.
"""
logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id)
def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None:
"""Handle moderation check event.
Intentionally a no-op: metrics and structured logs for moderation checks
are emitted directly by EnterpriseOtelTrace._moderation_trace at trace time,
not through the metric handler queue path.
"""
logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id)
def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None:
"""Handle suggested question event.
Intentionally a no-op: metrics and structured logs for suggested questions
are emitted directly by EnterpriseOtelTrace._suggested_question_trace at
trace time, not through the metric handler queue path.
"""
logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id)
def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None:
"""Handle dataset retrieval event.
Intentionally a no-op: metrics and structured logs for dataset retrievals
are emitted directly by EnterpriseOtelTrace._dataset_retrieval_trace at
trace time, not through the metric handler queue path.
"""
logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id)
def _on_generate_name(self, envelope: TelemetryEnvelope) -> None:
"""Handle generate name event.
Intentionally a no-op: metrics and structured logs for generate name
operations are emitted directly by EnterpriseOtelTrace._generate_name_trace
at trace time, not through the metric handler queue path.
"""
logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id)
def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None:
"""Handle prompt generation event.
Intentionally a no-op: metrics and structured logs for prompt generation
operations are emitted directly by EnterpriseOtelTrace._prompt_generation_trace
at trace time, not through the metric handler queue path.
"""
logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id)

View File

@@ -0,0 +1,122 @@
"""Structured-log emitter for enterprise telemetry events.
Emits structured JSON log lines correlated with OTEL traces via trace_id.
Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic.
"""
from __future__ import annotations
import logging
import uuid
from functools import lru_cache
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
logger = logging.getLogger("dify.telemetry")
@lru_cache(maxsize=4096)
def compute_trace_id_hex(uuid_str: str | None) -> str:
"""Convert a business UUID string to a 32-hex OTEL-compatible trace_id.
Returns empty string when *uuid_str* is ``None`` or invalid.
"""
if not uuid_str:
return ""
normalized = uuid_str.strip().lower()
if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized):
return normalized
try:
return f"{uuid.UUID(normalized).int:032x}"
except (ValueError, AttributeError):
return ""
@lru_cache(maxsize=4096)
def compute_span_id_hex(uuid_str: str | None) -> str:
if not uuid_str:
return ""
normalized = uuid_str.strip().lower()
if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized):
return normalized
try:
from enterprise.telemetry.id_generator import compute_deterministic_span_id
return f"{compute_deterministic_span_id(normalized):016x}"
except (ValueError, AttributeError):
return ""
def emit_telemetry_log(
*,
event_name: str | EnterpriseTelemetryEvent,
attributes: dict[str, Any],
signal: str = "metric_only",
trace_id_source: str | None = None,
span_id_source: str | None = None,
tenant_id: str | None = None,
user_id: str | None = None,
) -> None:
"""Emit a structured log line for a telemetry event.
Parameters
----------
event_name:
Canonical event name, e.g. ``"dify.workflow.run"``.
attributes:
All event-specific attributes (already built by the caller).
signal:
``"metric_only"`` for events with no span, ``"span_detail"``
for detail logs accompanying a slim span.
trace_id_source:
A UUID string (e.g. ``workflow_run_id``) used to derive a 32-hex
trace_id for cross-signal correlation.
span_id_source:
A UUID string (e.g. ``node_execution_id``) used to derive a 16-hex
span_id that matches the exporter's deterministic span IDs.
tenant_id:
Tenant identifier (for the ``IdentityContextFilter``).
user_id:
User identifier (for the ``IdentityContextFilter``).
"""
if not logger.isEnabledFor(logging.INFO):
return
attrs = {
"dify.event.name": event_name,
"dify.event.signal": signal,
**attributes,
}
extra: dict[str, Any] = {"attributes": attrs}
trace_id_hex = compute_trace_id_hex(trace_id_source)
if trace_id_hex:
extra["trace_id"] = trace_id_hex
span_id_hex = compute_span_id_hex(span_id_source)
if span_id_hex:
extra["span_id"] = span_id_hex
if tenant_id:
extra["tenant_id"] = tenant_id
if user_id:
extra["user_id"] = user_id
logger.info("telemetry.%s", signal, extra=extra)
def emit_metric_only_event(
*,
event_name: str | EnterpriseTelemetryEvent,
attributes: dict[str, Any],
trace_id_source: str | None = None,
span_id_source: str | None = None,
tenant_id: str | None = None,
user_id: str | None = None,
) -> None:
emit_telemetry_log(
event_name=event_name,
attributes=attributes,
signal="metric_only",
trace_id_source=trace_id_source,
span_id_source=span_id_source,
tenant_id=tenant_id,
user_id=user_id,
)
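The hex IDs attached to these log lines match the trace_id the CorrelationIdGenerator derives from the same business UUID, which is what makes log/span joins possible. A quick check with a hypothetical run ID:

# Sketch: the logged 32-hex trace_id equals the generator's UUID-derived trace_id.
import uuid

workflow_run_id = "11111111-2222-4333-8444-555555555555"  # hypothetical
hex_id = compute_trace_id_hex(workflow_run_id)
assert hex_id == f"{uuid.UUID(workflow_run_id).int:032x}"
print(hex_id)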

View File

@@ -11,3 +11,9 @@ app_published_workflow_was_updated = signal("app-published-workflow-was-updated"
# sender: app, kwargs: synced_draft_workflow
app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced")
# sender: app
app_was_updated = signal("app-was-updated")
# sender: app
app_was_deleted = signal("app-was-deleted")

View File

@@ -204,6 +204,8 @@ def init_app(app: DifyApp) -> Celery:
"schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL),
}
if dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED:
imports.append("tasks.enterprise_telemetry_task")
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app

View File

@@ -0,0 +1,50 @@
"""Flask extension for enterprise telemetry lifecycle management.
Initializes the EnterpriseExporter singleton during ``create_app()``
(single-threaded), registers blinker event handlers, and hooks atexit
for graceful shutdown.
Skipped entirely when either ``ENTERPRISE_ENABLED`` or ``ENTERPRISE_TELEMETRY_ENABLED``
is false (``is_enabled()`` gate).
"""
from __future__ import annotations
import atexit
import logging
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from dify_app import DifyApp
from enterprise.telemetry.exporter import EnterpriseExporter
logger = logging.getLogger(__name__)
_exporter: EnterpriseExporter | None = None
def is_enabled() -> bool:
return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED)
def init_app(app: DifyApp) -> None:
global _exporter
if not is_enabled():
return
from enterprise.telemetry.exporter import EnterpriseExporter
_exporter = EnterpriseExporter(dify_config)
atexit.register(_exporter.shutdown)
# Import to trigger @signal.connect decorator registration
import enterprise.telemetry.event_handlers # noqa: F401 # type: ignore[reportUnusedImport]
logger.info("Enterprise telemetry initialized")
def get_enterprise_exporter() -> EnterpriseExporter | None:
return _exporter

View File

@@ -78,16 +78,24 @@ def init_app(app: DifyApp):
protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower()
if dify_config.OTEL_EXPORTER_TYPE == "otlp":
if protocol == "grpc":
# Auto-detect TLS: https:// uses secure, everything else is insecure
endpoint = dify_config.OTLP_BASE_ENDPOINT
insecure = not endpoint.startswith("https://")
# Header field names must consist of lowercase letters, check RFC7540
grpc_headers = (
(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else ()
)
exporter = GRPCSpanExporter(
- endpoint=dify_config.OTLP_BASE_ENDPOINT,
- # Header field names must consist of lowercase letters, check RFC7540
- headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
- insecure=True,
+ endpoint=endpoint,
+ headers=grpc_headers,
+ insecure=insecure,
)
metric_exporter = GRPCMetricExporter(
endpoint=dify_config.OTLP_BASE_ENDPOINT,
headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
insecure=True,
endpoint=endpoint,
headers=grpc_headers,
insecure=insecure,
)
else:
headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None

View File

@@ -60,7 +60,7 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
model.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
model.node_id = data.get("node_id") or ""
model.node_type = data.get("node_type") or ""
model.status = data.get("status") or "running" # Default status if missing
model.status = WorkflowNodeExecutionStatus(data.get("status") or "running")
model.title = data.get("title") or ""
created_by_role_val = data.get("created_by_role")
try:

View File

@@ -5,7 +5,7 @@ This module provides parsers that extract node-specific metadata and set
OpenTelemetry span attributes according to semantic conventions.
"""
from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps
from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps, should_include_content
from extensions.otel.parser.llm import LLMNodeOTelParser
from extensions.otel.parser.retrieval import RetrievalNodeOTelParser
from extensions.otel.parser.tool import ToolNodeOTelParser
@@ -17,4 +17,5 @@ __all__ = [
"RetrievalNodeOTelParser",
"ToolNodeOTelParser",
"safe_json_dumps",
"should_include_content",
]

View File

@@ -1,5 +1,10 @@
"""
Base parser interface and utilities for OpenTelemetry node parsers.
Content gating: ``should_include_content()`` controls whether content-bearing
span attributes (inputs, outputs, prompts, completions, documents) are written.
The gate is only active in EE (``ENTERPRISE_ENABLED=True``) when
``ENTERPRISE_INCLUDE_CONTENT=False``; CE behaviour is unchanged.
"""
import json
@@ -9,6 +14,7 @@ from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from pydantic import BaseModel
from configs import dify_config
from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes
from graphon.enums import BuiltinNodeTypes
from graphon.file.models import File
@@ -17,6 +23,16 @@ from graphon.nodes.base.node import Node
from graphon.variables import Segment
def should_include_content() -> bool:
"""Return True if content should be written to spans.
CE (ENTERPRISE_ENABLED=False): always True — no behaviour change.
"""
if not dify_config.ENTERPRISE_ENABLED:
return True
return dify_config.ENTERPRISE_INCLUDE_CONTENT
def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
"""
Safely serialize objects to JSON, handling non-serializable types.
@@ -101,10 +117,11 @@ class DefaultNodeOTelParser:
# Extract inputs and outputs from result_event
if result_event and result_event.node_run_result:
node_run_result = result_event.node_run_result
if node_run_result.inputs:
span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
if node_run_result.outputs:
span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
if should_include_content():
if node_run_result.inputs:
span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
if node_run_result.outputs:
span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
if error:
span.record_exception(error)
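The gate therefore has three outcomes. A small sketch of the decision table, mirroring should_include_content() on explicit arguments:

def include_content(enterprise_enabled: bool, enterprise_include_content: bool) -> bool:
    # Same logic as should_include_content(), but parameterised for illustration
    if not enterprise_enabled:
        return True
    return enterprise_include_content

assert include_content(False, False) is True   # CE: content always written, behaviour unchanged
assert include_content(True, True) is True     # EE with content explicitly allowed
assert include_content(True, False) is False   # EE with content redacted from spans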

View File

@@ -21,3 +21,15 @@ class DifySpanAttributes:
INVOKE_FROM = "dify.invoke_from"
"""Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER."""
INVOKED_BY = "dify.invoked_by"
"""Invoked by, e.g. end_user, account, user."""
USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
"""Number of input tokens (prompt tokens) used."""
USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
"""Number of output tokens (completion tokens) generated."""
USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
"""Total number of tokens used."""

View File

@@ -1,45 +0,0 @@
from flask_restx import fields
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
# Snippet list item fields (lightweight for list display)
snippet_list_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"version": fields.Integer,
"use_count": fields.Integer,
"is_published": fields.Boolean,
"icon_info": fields.Raw,
"created_at": TimestampField,
"updated_at": TimestampField,
}
# Full snippet fields (includes creator info and graph data)
snippet_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"version": fields.Integer,
"use_count": fields.Integer,
"is_published": fields.Boolean,
"icon_info": fields.Raw,
"graph": fields.Raw(attribute="graph_dict"),
"input_fields": fields.Raw(attribute="input_fields_list"),
"created_by": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField,
"updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
"updated_at": TimestampField,
}
# Pagination response fields
snippet_pagination_fields = {
"data": fields.List(fields.Nested(snippet_list_fields)),
"page": fields.Integer,
"limit": fields.Integer,
"total": fields.Integer,
"has_more": fields.Boolean,
}

View File

@@ -14,7 +14,6 @@ workflow_app_log_partial_fields = {
"id": fields.String,
"workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True),
"details": fields.Raw(attribute="details"),
"evaluation": fields.Raw(attribute="evaluation", default=None),
"created_from": fields.String,
"created_by_role": fields.String,
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),

View File

@@ -105,7 +105,6 @@ class WorkflowType(StrEnum):
WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag-pipeline"
SNIPPET = "snippet"
class WorkflowExecutionStatus(StrEnum):

View File

@@ -52,6 +52,12 @@ class ReadyQueueProtocol(Protocol):
...
class NodeExecutionProtocol(Protocol):
"""Structural interface for persisted per-node execution state."""
execution_id: str | None
class GraphExecutionProtocol(Protocol):
"""Structural interface for graph execution aggregate.
@@ -67,6 +73,11 @@ class GraphExecutionProtocol(Protocol):
exceptions_count: int
pause_reasons: list[PauseReason]
@property
def node_executions(self) -> Mapping[str, NodeExecutionProtocol]:
"""Return the persisted node execution state keyed by node id."""
...
def start(self) -> None:
"""Transition execution into the running state."""
...
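Because these are typing.Protocol classes, any object with the matching shape satisfies them structurally; no inheritance is required. A minimal sketch with a dataclass invented for illustration:

from dataclasses import dataclass

@dataclass
class InMemoryNodeExecution:
    execution_id: str | None = None

def read_execution_id(execution: "NodeExecutionProtocol") -> str | None:
    # InMemoryNodeExecution satisfies NodeExecutionProtocol structurally,
    # because it exposes an execution_id: str | None attribute.
    return execution.execution_id

print(read_execution_id(InMemoryNodeExecution(execution_id="abc-123")))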

View File

@@ -1,83 +0,0 @@
"""add_customized_snippets_table
Revision ID: 1c05e80d2380
Revises: 788d3099ae3a
Create Date: 2026-01-29 12:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import models as models
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = "1c05e80d2380"
down_revision = "788d3099ae3a"
branch_labels = None
depends_on = None
def upgrade():
conn = op.get_bind()
if _is_pg(conn):
op.create_table(
"customized_snippets",
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
sa.Column("icon_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("graph", sa.Text(), nullable=True),
sa.Column("input_fields", sa.Text(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
else:
op.create_table(
"customized_snippets",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("description", models.types.LongText(), nullable=True),
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
sa.Column("icon_info", models.types.AdjustedJSON(astext_type=sa.Text()), nullable=True),
sa.Column("graph", models.types.LongText(), nullable=True),
sa.Column("input_fields", models.types.LongText(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
batch_op.create_index("customized_snippet_tenant_idx", ["tenant_id"], unique=False)
def downgrade():
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
batch_op.drop_index("customized_snippet_tenant_idx")
op.drop_table("customized_snippets")

View File

@@ -1,116 +0,0 @@
"""add_evaluation_tables
Revision ID: a1b2c3d4e5f6
Revises: 1c05e80d2380
Create Date: 2026-03-03 00:01:00.000000
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f6"
down_revision = "1c05e80d2380"
branch_labels = None
depends_on = None
def upgrade():
# evaluation_configurations
op.create_table(
"evaluation_configurations",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("target_type", sa.String(length=20), nullable=False),
sa.Column("target_id", models.types.StringUUID(), nullable=False),
sa.Column("evaluation_model_provider", sa.String(length=255), nullable=True),
sa.Column("evaluation_model", sa.String(length=255), nullable=True),
sa.Column("metrics_config", models.types.LongText(), nullable=True),
sa.Column("judgement_conditions", models.types.LongText(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="evaluation_configuration_pkey"),
sa.UniqueConstraint("tenant_id", "target_type", "target_id", name="evaluation_configuration_unique"),
)
with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op:
batch_op.create_index(
"evaluation_configuration_target_idx", ["tenant_id", "target_type", "target_id"], unique=False
)
# evaluation_runs
op.create_table(
"evaluation_runs",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("target_type", sa.String(length=20), nullable=False),
sa.Column("target_id", models.types.StringUUID(), nullable=False),
sa.Column("evaluation_config_id", models.types.StringUUID(), nullable=False),
sa.Column("status", sa.String(length=20), nullable=False, server_default=sa.text("'pending'")),
sa.Column("dataset_file_id", models.types.StringUUID(), nullable=True),
sa.Column("result_file_id", models.types.StringUUID(), nullable=True),
sa.Column("total_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("completed_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("failed_items", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("metrics_summary", models.types.LongText(), nullable=True),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("celery_task_id", sa.String(length=255), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=False),
sa.Column("started_at", sa.DateTime(), nullable=True),
sa.Column("completed_at", sa.DateTime(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="evaluation_run_pkey"),
)
with op.batch_alter_table("evaluation_runs", schema=None) as batch_op:
batch_op.create_index(
"evaluation_run_target_idx", ["tenant_id", "target_type", "target_id"], unique=False
)
batch_op.create_index("evaluation_run_status_idx", ["tenant_id", "status"], unique=False)
# evaluation_run_items
op.create_table(
"evaluation_run_items",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("evaluation_run_id", models.types.StringUUID(), nullable=False),
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True),
sa.Column("item_index", sa.Integer(), nullable=False),
sa.Column("inputs", models.types.LongText(), nullable=True),
sa.Column("expected_output", models.types.LongText(), nullable=True),
sa.Column("context", models.types.LongText(), nullable=True),
sa.Column("actual_output", models.types.LongText(), nullable=True),
sa.Column("metrics", models.types.LongText(), nullable=True),
sa.Column("metadata_json", models.types.LongText(), nullable=True),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("overall_score", sa.Float(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="evaluation_run_item_pkey"),
)
with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op:
batch_op.create_index("evaluation_run_item_run_idx", ["evaluation_run_id"], unique=False)
batch_op.create_index(
"evaluation_run_item_index_idx", ["evaluation_run_id", "item_index"], unique=False
)
batch_op.create_index("evaluation_run_item_workflow_run_idx", ["workflow_run_id"], unique=False)
def downgrade():
with op.batch_alter_table("evaluation_run_items", schema=None) as batch_op:
batch_op.drop_index("evaluation_run_item_workflow_run_idx")
batch_op.drop_index("evaluation_run_item_index_idx")
batch_op.drop_index("evaluation_run_item_run_idx")
op.drop_table("evaluation_run_items")
with op.batch_alter_table("evaluation_runs", schema=None) as batch_op:
batch_op.drop_index("evaluation_run_status_idx")
batch_op.drop_index("evaluation_run_target_idx")
op.drop_table("evaluation_runs")
with op.batch_alter_table("evaluation_configurations", schema=None) as batch_op:
batch_op.drop_index("evaluation_configuration_target_idx")
op.drop_table("evaluation_configurations")

View File

@@ -1,25 +0,0 @@
"""merge migration heads
Revision ID: 4c60d8d3ee74
Revises: fce013ca180e, a1b2c3d4e5f6
Create Date: 2026-03-17 17:21:12.105536
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4c60d8d3ee74'
down_revision = ('fce013ca180e', 'a1b2c3d4e5f6')
branch_labels = None
depends_on = None
def upgrade():
pass
def downgrade():
pass

View File

@@ -33,13 +33,6 @@ from .enums import (
WorkflowRunTriggeredFrom,
WorkflowTriggerStatus,
)
from .evaluation import (
EvaluationConfiguration,
EvaluationRun,
EvaluationRunItem,
EvaluationRunStatus,
EvaluationTargetType,
)
from .execution_extra_content import ExecutionExtraContent, HumanInputContent
from .human_input import HumanInputForm
from .model import (
@@ -87,7 +80,6 @@ from .provider import (
TenantDefaultModel,
TenantPreferredModelProvider,
)
from .snippet import CustomizedSnippet, SnippetType
from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from .task import CeleryTask, CeleryTaskSet
from .tools import (
@@ -147,7 +139,6 @@ __all__ = [
"Conversation",
"ConversationVariable",
"CreatorUserRole",
"CustomizedSnippet",
"DataSourceApiKeyAuthBinding",
"DataSourceOauthBinding",
"Dataset",
@@ -165,11 +156,6 @@ __all__ = [
"DocumentSegment",
"Embedding",
"EndUser",
"EvaluationConfiguration",
"EvaluationRun",
"EvaluationRunItem",
"EvaluationRunStatus",
"EvaluationTargetType",
"ExecutionExtraContent",
"ExporleBanner",
"ExternalKnowledgeApis",
@@ -197,7 +183,6 @@ __all__ = [
"RecommendedApp",
"SavedMessage",
"Site",
"SnippetType",
"Tag",
"TagBinding",
"Tenant",

View File

@@ -1,183 +0,0 @@
from __future__ import annotations
import json
from datetime import datetime
from enum import StrEnum
from typing import Any
import sqlalchemy as sa
from sqlalchemy import DateTime, Float, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
from .base import Base
from .types import LongText, StringUUID
class EvaluationRunStatus(StrEnum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class EvaluationTargetType(StrEnum):
APP = "app"
SNIPPETS = "snippets"
KNOWLEDGE_BASE = "knowledge_base"
class EvaluationConfiguration(Base):
"""Stores evaluation configuration for each target (App or Snippet)."""
__tablename__ = "evaluation_configurations"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="evaluation_configuration_pkey"),
sa.Index("evaluation_configuration_target_idx", "tenant_id", "target_type", "target_id"),
sa.UniqueConstraint("tenant_id", "target_type", "target_id", name="evaluation_configuration_unique"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
target_type: Mapped[str] = mapped_column(String(20), nullable=False)
target_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
evaluation_model_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
evaluation_model: Mapped[str | None] = mapped_column(String(255), nullable=True)
metrics_config: Mapped[str | None] = mapped_column(LongText, nullable=True)
judgement_conditions: Mapped[str | None] = mapped_column(LongText, nullable=True)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
updated_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def metrics_config_dict(self) -> dict[str, Any]:
if self.metrics_config:
return json.loads(self.metrics_config)
return {}
@metrics_config_dict.setter
def metrics_config_dict(self, value: dict[str, Any]) -> None:
self.metrics_config = json.dumps(value)
@property
def judgement_conditions_dict(self) -> dict[str, Any]:
if self.judgement_conditions:
return json.loads(self.judgement_conditions)
return {}
@judgement_conditions_dict.setter
def judgement_conditions_dict(self, value: dict[str, Any]) -> None:
self.judgement_conditions = json.dumps(value)
def __repr__(self) -> str:
return f"<EvaluationConfiguration(id={self.id}, target={self.target_type}:{self.target_id})>"
class EvaluationRun(Base):
"""Stores each evaluation run record."""
__tablename__ = "evaluation_runs"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="evaluation_run_pkey"),
sa.Index("evaluation_run_target_idx", "tenant_id", "target_type", "target_id"),
sa.Index("evaluation_run_status_idx", "tenant_id", "status"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
target_type: Mapped[str] = mapped_column(String(20), nullable=False)
target_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
evaluation_config_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
status: Mapped[str] = mapped_column(String(20), nullable=False, default=EvaluationRunStatus.PENDING)
dataset_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
result_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
total_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
completed_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
failed_items: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
error: Mapped[str | None] = mapped_column(Text, nullable=True)
celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def progress(self) -> float:
if self.total_items == 0:
return 0.0
return (self.completed_items + self.failed_items) / self.total_items
def __repr__(self) -> str:
return f"<EvaluationRun(id={self.id}, status={self.status})>"
class EvaluationRunItem(Base):
"""Stores per-row evaluation results."""
__tablename__ = "evaluation_run_items"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="evaluation_run_item_pkey"),
sa.Index("evaluation_run_item_run_idx", "evaluation_run_id"),
sa.Index("evaluation_run_item_index_idx", "evaluation_run_id", "item_index"),
sa.Index("evaluation_run_item_workflow_run_idx", "workflow_run_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
evaluation_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
item_index: Mapped[int] = mapped_column(Integer, nullable=False)
inputs: Mapped[str | None] = mapped_column(LongText, nullable=True)
expected_output: Mapped[str | None] = mapped_column(LongText, nullable=True)
context: Mapped[str | None] = mapped_column(LongText, nullable=True)
actual_output: Mapped[str | None] = mapped_column(LongText, nullable=True)
metrics: Mapped[str | None] = mapped_column(LongText, nullable=True)
judgment: Mapped[str | None] = mapped_column(LongText, nullable=True)
metadata_json: Mapped[str | None] = mapped_column(LongText, nullable=True)
error: Mapped[str | None] = mapped_column(Text, nullable=True)
overall_score: Mapped[float | None] = mapped_column(Float, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@property
def inputs_dict(self) -> dict[str, Any]:
if self.inputs:
return json.loads(self.inputs)
return {}
@property
def metrics_list(self) -> list[dict[str, Any]]:
if self.metrics:
return json.loads(self.metrics)
return []
@property
def judgment_dict(self) -> dict[str, Any]:
if self.judgment:
return json.loads(self.judgment)
return {}
@property
def metadata_dict(self) -> dict[str, Any]:
if self.metadata_json:
return json.loads(self.metadata_json)
return {}
def __repr__(self) -> str:
return f"<EvaluationRunItem(id={self.id}, run={self.evaluation_run_id}, index={self.item_index})>"

View File

@@ -1,101 +0,0 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import Any
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
from .account import Account
from .base import Base
from .engine import db
from .types import AdjustedJSON, LongText, StringUUID
class SnippetType(StrEnum):
"""Snippet Type Enum"""
NODE = "node"
GROUP = "group"
class CustomizedSnippet(Base):
"""
Customized Snippet Model
Stores reusable workflow components (nodes or node groups) that can be
shared across applications within a workspace.
"""
__tablename__ = "customized_snippets"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.Index("customized_snippet_tenant_idx", "tenant_id"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str | None] = mapped_column(LongText, nullable=True)
type: Mapped[str] = mapped_column(String(50), nullable=False, server_default=sa.text("'node'"))
# Workflow reference for published version
workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
# State flags
is_published: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
version: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("1"))
use_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
# Visual customization
icon_info: Mapped[dict | None] = mapped_column(AdjustedJSON, nullable=True)
# Snippet configuration (stored as JSON text)
input_fields: Mapped[str | None] = mapped_column(LongText, nullable=True)
# Audit fields
created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def graph_dict(self) -> dict[str, Any]:
"""Get graph from associated workflow."""
if self.workflow_id:
from .workflow import Workflow
workflow = db.session.get(Workflow, self.workflow_id)
if workflow:
return json.loads(workflow.graph) if workflow.graph else {}
return {}
@property
def input_fields_list(self) -> list[dict[str, Any]]:
"""Parse input_fields JSON to list."""
return json.loads(self.input_fields) if self.input_fields else []
@property
def created_by_account(self) -> Account | None:
"""Get the account that created this snippet."""
if self.created_by:
return db.session.get(Account, self.created_by)
return None
@property
def updated_by_account(self) -> Account | None:
"""Get the account that last updated this snippet."""
if self.updated_by:
return db.session.get(Account, self.updated_by)
return None
@property
def version_str(self) -> str:
"""Get version as string for API response."""
return str(self.version)

View File

@@ -33,7 +33,13 @@ from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey
from graphon.enums import (
BuiltinNodeTypes,
NodeType,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from graphon.file.constants import maybe_file_object
from graphon.file.models import File
from graphon.variables import utils as variable_utils
@@ -99,7 +105,6 @@ class WorkflowType(StrEnum):
WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag-pipeline"
SNIPPET = "snippet"
@classmethod
def value_of(cls, value: str) -> "WorkflowType":
@@ -942,7 +947,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
inputs: Mapped[str | None] = mapped_column(LongText)
process_data: Mapped[str | None] = mapped_column(LongText)
outputs: Mapped[str | None] = mapped_column(LongText)
status: Mapped[str] = mapped_column(String(255))
status: Mapped[WorkflowNodeExecutionStatus] = mapped_column(EnumText(WorkflowNodeExecutionStatus, length=255))
error: Mapped[str | None] = mapped_column(LongText)
elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0"))
execution_metadata: Mapped[str | None] = mapped_column(LongText)
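With the column now declared as EnumText(WorkflowNodeExecutionStatus), call sites assign the enum member rather than a bare string, which is why the repository hunk earlier wraps the raw value in WorkflowNodeExecutionStatus(...). The conversion step on its own, with hedged assumptions about member values:

from graphon.enums import WorkflowNodeExecutionStatus  # import path taken from this diff

def coerce_status(raw: str | None) -> WorkflowNodeExecutionStatus:
    # Mirrors the repository change: default missing values to "running",
    # then convert to the enum; unknown values raise ValueError.
    return WorkflowNodeExecutionStatus(raw or "running")

print(coerce_status(None))         # assumes "running" is a valid member value
print(coerce_status("succeeded"))  # assumes "succeeded" is a valid member value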

View File

@@ -198,12 +198,6 @@ storage = [
############################################################
tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
############################################################
# [ Evaluation ] dependency group
# Required for evaluation frameworks
############################################################
evaluation = ["ragas>=0.2.0", "deepeval>=2.0.0"]
############################################################
# [ VDB ] dependency group
# Required by vector store clients
@@ -237,26 +231,6 @@ vdb = [
"holo-search-sdk>=0.4.1",
]
[tool.mypy]
[[tool.mypy.overrides]]
# targeted ignores for current type-check errors
# TODO(QuantumGhost): suppress type errors in HITL related code.
# fix the type error later
module = [
"configs.middleware.cache.redis_pubsub_config",
"extensions.ext_redis",
"tasks.workflow_execution_tasks",
"graphon.nodes.base.node",
"services.human_input_delivery_test_service",
"core.app.apps.advanced_chat.app_generator",
"controllers.console.human_input_form",
"controllers.console.app.workflow_run",
"repositories.sqlalchemy_api_workflow_node_execution_repository",
"extensions.logstore.repositories.logstore_api_workflow_run_repository",
]
ignore_errors = true
[tool.pyrefly]
project-includes = ["."]
project-excludes = [".venv", "migrations/"]

View File

@@ -109,6 +109,15 @@ core/trigger/debug/event_selectors.py
core/trigger/entities/entities.py
core/trigger/provider.py
core/workflow/workflow_entry.py
enterprise/telemetry/contracts.py
enterprise/telemetry/draft_trace.py
enterprise/telemetry/enterprise_trace.py
enterprise/telemetry/entities/__init__.py
enterprise/telemetry/event_handlers.py
enterprise/telemetry/exporter.py
enterprise/telemetry/id_generator.py
enterprise/telemetry/metric_handler.py
enterprise/telemetry/telemetry_log.py
graphon/entities/workflow_execution.py
graphon/file/file_manager.py
graphon/graph_engine/error_handler.py

View File

@@ -12,7 +12,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from events.app_event import app_was_created, app_was_deleted, app_was_updated
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -281,6 +281,8 @@ class AppService:
app.updated_at = naive_utc_now()
db.session.commit()
app_was_updated.send(app)
return app
def update_app_name(self, app: App, name: str) -> App:
@@ -296,6 +298,8 @@ class AppService:
app.updated_at = naive_utc_now()
db.session.commit()
app_was_updated.send(app)
return app
def update_app_icon(self, app: App, icon: str, icon_background: str) -> App:
@@ -313,6 +317,8 @@ class AppService:
app.updated_at = naive_utc_now()
db.session.commit()
app_was_updated.send(app)
return app
def update_app_site_status(self, app: App, enable_site: bool) -> App:
@@ -330,6 +336,8 @@ class AppService:
app.updated_at = naive_utc_now()
db.session.commit()
app_was_updated.send(app)
return app
def update_app_api_status(self, app: App, enable_api: bool) -> App:
@@ -348,6 +356,8 @@ class AppService:
app.updated_at = naive_utc_now()
db.session.commit()
app_was_updated.send(app)
return app
def delete_app(self, app: App):
@@ -355,6 +365,8 @@ class AppService:
Delete app
:param app: App instance
"""
app_was_deleted.send(app)
db.session.delete(app)
db.session.commit()

View File

@@ -1,21 +0,0 @@
from services.errors.base import BaseServiceError
class EvaluationFrameworkNotConfiguredError(BaseServiceError):
def __init__(self, description: str | None = None):
super().__init__(description or "Evaluation framework is not configured. Set EVALUATION_FRAMEWORK env var.")
class EvaluationNotFoundError(BaseServiceError):
def __init__(self, description: str | None = None):
super().__init__(description or "Evaluation not found.")
class EvaluationDatasetInvalidError(BaseServiceError):
def __init__(self, description: str | None = None):
super().__init__(description or "Evaluation dataset is invalid.")
class EvaluationMaxConcurrentRunsError(BaseServiceError):
def __init__(self, description: str | None = None):
super().__init__(description or "Maximum number of concurrent evaluation runs reached.")

View File

@@ -1,897 +0,0 @@
import io
import json
import logging
from collections.abc import Mapping
from typing import Any, Union
from openpyxl import Workbook, load_workbook
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
from openpyxl.utils import get_column_letter
from sqlalchemy.orm import Session
from configs import dify_config
from core.evaluation.entities.evaluation_entity import (
METRIC_NODE_TYPE_MAPPING,
DefaultMetric,
EvaluationCategory,
EvaluationConfigData,
EvaluationDatasetInput,
EvaluationMetricName,
EvaluationRunData,
EvaluationRunRequest,
NodeInfo,
)
from core.evaluation.evaluation_manager import EvaluationManager
from graphon.enums import WorkflowNodeExecutionMetadataKey
from graphon.node_events.base import NodeRunResult
from models.evaluation import (
EvaluationConfiguration,
EvaluationRun,
EvaluationRunItem,
EvaluationRunStatus,
)
from models.model import App, AppMode
from models.snippet import CustomizedSnippet
from services.errors.evaluation import (
EvaluationDatasetInvalidError,
EvaluationFrameworkNotConfiguredError,
EvaluationMaxConcurrentRunsError,
EvaluationNotFoundError,
)
from services.snippet_service import SnippetService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
class EvaluationService:
"""
Service for evaluation-related operations.
Provides functionality to generate evaluation dataset templates
based on App or Snippet input parameters.
"""
# Excluded app modes that don't support evaluation templates
EXCLUDED_APP_MODES = {AppMode.RAG_PIPELINE}
@classmethod
def generate_dataset_template(
cls,
target: Union[App, CustomizedSnippet],
target_type: str,
) -> tuple[bytes, str]:
"""
Generate evaluation dataset template as XLSX bytes.
Creates an XLSX file with headers based on the evaluation target's input parameters.
The first column is index, followed by input parameter columns.
:param target: App or CustomizedSnippet instance
:param target_type: Target type string ("app" or "snippet")
:return: Tuple of (xlsx_content_bytes, filename)
:raises ValueError: If target type is not supported or app mode is excluded
"""
# Validate target type
if target_type == "app":
if not isinstance(target, App):
raise ValueError("Invalid target: expected App instance")
if AppMode.value_of(target.mode) in cls.EXCLUDED_APP_MODES:
raise ValueError(f"App mode '{target.mode}' does not support evaluation templates")
input_fields = cls._get_app_input_fields(target)
elif target_type == "snippet":
if not isinstance(target, CustomizedSnippet):
raise ValueError("Invalid target: expected CustomizedSnippet instance")
input_fields = cls._get_snippet_input_fields(target)
else:
raise ValueError(f"Unsupported target type: {target_type}")
# Generate XLSX template
xlsx_content = cls._generate_xlsx_template(input_fields, target.name)
# Build filename
truncated_name = target.name[:10] + "..." if len(target.name) > 10 else target.name
filename = f"{truncated_name}-evaluation-dataset.xlsx"
return xlsx_content, filename
@classmethod
def _get_app_input_fields(cls, app: App) -> list[dict]:
"""
Get input fields from App's workflow.
:param app: App instance
:return: List of input field definitions
"""
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app)
if not workflow:
workflow = workflow_service.get_draft_workflow(app_model=app)
if not workflow:
return []
# Get user input form from workflow
user_input_form = workflow.user_input_form()
return user_input_form
@classmethod
def _get_snippet_input_fields(cls, snippet: CustomizedSnippet) -> list[dict]:
"""
Get input fields from Snippet.
Tries to get from snippet's own input_fields first,
then falls back to workflow's user_input_form.
:param snippet: CustomizedSnippet instance
:return: List of input field definitions
"""
# Try snippet's own input_fields first
input_fields = snippet.input_fields_list
if input_fields:
return input_fields
# Fallback to workflow's user_input_form
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet=snippet)
if not workflow:
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if workflow:
return workflow.user_input_form()
return []
@classmethod
def _generate_xlsx_template(cls, input_fields: list[dict], target_name: str) -> bytes:
"""
Generate XLSX template file content.
Creates a workbook with:
- First row as header row with "index" and input field names
- Styled header with background color and borders
- Empty data rows ready for user input
:param input_fields: List of input field definitions
:param target_name: Name of the target (for sheet name)
:return: XLSX file content as bytes
"""
wb = Workbook()
ws = wb.active
if ws is None:
ws = wb.create_sheet("Evaluation Dataset")
sheet_name = "Evaluation Dataset"
ws.title = sheet_name
header_font = Font(bold=True, color="FFFFFF")
header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
header_alignment = Alignment(horizontal="center", vertical="center")
thin_border = Border(
left=Side(style="thin"),
right=Side(style="thin"),
top=Side(style="thin"),
bottom=Side(style="thin"),
)
# Build header row
headers = ["index"]
for field in input_fields:
field_label = str(field.get("label") or field.get("variable") or "")
headers.append(field_label)
# Write header row
for col_idx, header in enumerate(headers, start=1):
cell = ws.cell(row=1, column=col_idx, value=header)
cell.font = header_font
cell.fill = header_fill
cell.alignment = header_alignment
cell.border = thin_border
# Set column widths
ws.column_dimensions["A"].width = 10 # index column
for col_idx in range(2, len(headers) + 1):
ws.column_dimensions[get_column_letter(col_idx)].width = 20
# Add one empty row with row number for user reference
for col_idx in range(1, len(headers) + 1):
cell = ws.cell(row=2, column=col_idx, value="")
cell.border = thin_border
if col_idx == 1:
cell.value = 1
cell.alignment = Alignment(horizontal="center")
# Save to bytes
output = io.BytesIO()
wb.save(output)
output.seek(0)
return output.getvalue()
@classmethod
def generate_retrieval_dataset_template(cls) -> tuple[bytes, str]:
"""Generate evaluation dataset XLSX template for knowledge base retrieval.
The template contains three columns: ``index``, ``query``, and
``expected_output``. Callers upload a filled copy and start an
evaluation run with ``target_type="dataset"``.
:returns: (xlsx_content_bytes, filename)
"""
wb = Workbook()
ws = wb.active
if ws is None:
ws = wb.create_sheet("Evaluation Dataset")
ws.title = "Evaluation Dataset"
header_font = Font(bold=True, color="FFFFFF")
header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
header_alignment = Alignment(horizontal="center", vertical="center")
thin_border = Border(
left=Side(style="thin"),
right=Side(style="thin"),
top=Side(style="thin"),
bottom=Side(style="thin"),
)
headers = ["index", "query", "expected_output"]
for col_idx, header in enumerate(headers, start=1):
cell = ws.cell(row=1, column=col_idx, value=header)
cell.font = header_font
cell.fill = header_fill
cell.alignment = header_alignment
cell.border = thin_border
ws.column_dimensions["A"].width = 10
ws.column_dimensions["B"].width = 30
ws.column_dimensions["C"].width = 30
# Add one sample row
for col_idx in range(1, len(headers) + 1):
cell = ws.cell(row=2, column=col_idx, value="")
cell.border = thin_border
if col_idx == 1:
cell.value = 1
cell.alignment = Alignment(horizontal="center")
output = io.BytesIO()
wb.save(output)
output.seek(0)
return output.getvalue(), "retrieval-evaluation-dataset.xlsx"
# ---- Evaluation Configuration CRUD ----
@classmethod
def get_evaluation_config(
cls,
session: Session,
tenant_id: str,
target_type: str,
target_id: str,
) -> EvaluationConfiguration | None:
return (
session.query(EvaluationConfiguration)
.filter_by(tenant_id=tenant_id, target_type=target_type, target_id=target_id)
.first()
)
@classmethod
def save_evaluation_config(
cls,
session: Session,
tenant_id: str,
target_type: str,
target_id: str,
account_id: str,
data: EvaluationConfigData,
) -> EvaluationConfiguration:
config = cls.get_evaluation_config(session, tenant_id, target_type, target_id)
if config is None:
config = EvaluationConfiguration(
tenant_id=tenant_id,
target_type=target_type,
target_id=target_id,
created_by=account_id,
updated_by=account_id,
)
session.add(config)
config.evaluation_model_provider = data.evaluation_model_provider
config.evaluation_model = data.evaluation_model
config.metrics_config = json.dumps(
{
"default_metrics": [m.model_dump() for m in data.default_metrics],
"customized_metrics": data.customized_metrics.model_dump() if data.customized_metrics else None,
}
)
config.judgement_conditions = json.dumps(data.judgment_config.model_dump() if data.judgment_config else {})
config.updated_by = account_id
session.commit()
session.refresh(config)
return config
# ---- Evaluation Run Management ----
@classmethod
def start_evaluation_run(
cls,
session: Session,
tenant_id: str,
target_type: str,
target_id: str,
account_id: str,
dataset_file_content: bytes,
run_request: EvaluationRunRequest,
) -> EvaluationRun:
"""Validate dataset, create run record, dispatch Celery task.
Saves the provided parameters as the latest EvaluationConfiguration
before creating the run.
"""
# Check framework is configured
evaluation_instance = EvaluationManager.get_evaluation_instance()
if evaluation_instance is None:
raise EvaluationFrameworkNotConfiguredError()
# Save as latest EvaluationConfiguration
config = cls.save_evaluation_config(
session=session,
tenant_id=tenant_id,
target_type=target_type,
target_id=target_id,
account_id=account_id,
data=run_request,
)
# Check concurrent run limit
active_runs = (
session.query(EvaluationRun)
.filter_by(tenant_id=tenant_id)
.filter(EvaluationRun.status.in_([EvaluationRunStatus.PENDING, EvaluationRunStatus.RUNNING]))
.count()
)
max_concurrent = dify_config.EVALUATION_MAX_CONCURRENT_RUNS
if active_runs >= max_concurrent:
raise EvaluationMaxConcurrentRunsError(f"Maximum concurrent runs ({max_concurrent}) reached.")
# Parse dataset
items = cls._parse_dataset(dataset_file_content)
max_rows = dify_config.EVALUATION_MAX_DATASET_ROWS
if len(items) > max_rows:
raise EvaluationDatasetInvalidError(f"Dataset has {len(items)} rows, max is {max_rows}.")
# Create evaluation run
evaluation_run = EvaluationRun(
tenant_id=tenant_id,
target_type=target_type,
target_id=target_id,
evaluation_config_id=config.id,
status=EvaluationRunStatus.PENDING,
total_items=len(items),
created_by=account_id,
)
session.add(evaluation_run)
session.commit()
session.refresh(evaluation_run)
# Build Celery task data
run_data = EvaluationRunData(
evaluation_run_id=evaluation_run.id,
tenant_id=tenant_id,
target_type=target_type,
target_id=target_id,
evaluation_model_provider=run_request.evaluation_model_provider,
evaluation_model=run_request.evaluation_model,
default_metrics=run_request.default_metrics,
customized_metrics=run_request.customized_metrics,
judgment_config=run_request.judgment_config,
input_list=items,
)
# Dispatch Celery task
from tasks.evaluation_task import run_evaluation
task = run_evaluation.delay(run_data.model_dump())
evaluation_run.celery_task_id = task.id
session.commit()
return evaluation_run
@classmethod
def get_evaluation_runs(
cls,
session: Session,
tenant_id: str,
target_type: str,
target_id: str,
page: int = 1,
page_size: int = 20,
) -> tuple[list[EvaluationRun], int]:
"""Query evaluation run history with pagination."""
query = (
session.query(EvaluationRun)
.filter_by(tenant_id=tenant_id, target_type=target_type, target_id=target_id)
.order_by(EvaluationRun.created_at.desc())
)
total = query.count()
runs = query.offset((page - 1) * page_size).limit(page_size).all()
return runs, total
@classmethod
def get_evaluation_run_detail(
cls,
session: Session,
tenant_id: str,
run_id: str,
) -> EvaluationRun:
run = session.query(EvaluationRun).filter_by(id=run_id, tenant_id=tenant_id).first()
if not run:
raise EvaluationNotFoundError("Evaluation run not found.")
return run
@classmethod
def get_evaluation_run_items(
cls,
session: Session,
run_id: str,
page: int = 1,
page_size: int = 50,
) -> tuple[list[EvaluationRunItem], int]:
"""Query evaluation run items with pagination."""
query = (
session.query(EvaluationRunItem)
.filter_by(evaluation_run_id=run_id)
.order_by(EvaluationRunItem.item_index.asc())
)
total = query.count()
items = query.offset((page - 1) * page_size).limit(page_size).all()
return items, total
@classmethod
def cancel_evaluation_run(
cls,
session: Session,
tenant_id: str,
run_id: str,
) -> EvaluationRun:
run = cls.get_evaluation_run_detail(session, tenant_id, run_id)
if run.status not in (EvaluationRunStatus.PENDING, EvaluationRunStatus.RUNNING):
raise ValueError(f"Cannot cancel evaluation run in status: {run.status}")
run.status = EvaluationRunStatus.CANCELLED
# Revoke Celery task if running
if run.celery_task_id:
try:
from celery import current_app as celery_app
celery_app.control.revoke(run.celery_task_id, terminate=True)
except Exception:
logger.exception("Failed to revoke Celery task %s", run.celery_task_id)
session.commit()
return run
@classmethod
def get_supported_metrics(cls, category: EvaluationCategory) -> list[str]:
return EvaluationManager.get_supported_metrics(category)
@staticmethod
def get_available_metrics() -> list[str]:
"""Return the centrally-defined list of evaluation metrics."""
return [m.value for m in EvaluationMetricName]
@classmethod
def get_nodes_for_metrics(
cls,
target: Union[App, CustomizedSnippet],
target_type: str,
metrics: list[str] | None = None,
) -> dict[str, list[dict[str, str]]]:
"""Return node info grouped by metric (or all nodes when *metrics* is empty).
:param target: App or CustomizedSnippet instance.
:param target_type: ``"app"`` or ``"snippets"``.
:param metrics: Optional list of metric names to filter by.
When *None* or empty, returns ``{"all": [<every node>]}``.
:returns: ``{metric_name: [NodeInfo dict, ...]}`` or
``{"all": [NodeInfo dict, ...]}``.
"""
workflow = cls._resolve_workflow(target, target_type)
if not workflow:
return {"all": []} if not metrics else {m: [] for m in metrics}
if not metrics:
all_nodes = [
NodeInfo(node_id=node_id, type=node_data.get("type", ""), title=node_data.get("title", "")).model_dump()
for node_id, node_data in workflow.walk_nodes()
]
return {"all": all_nodes}
node_type_to_nodes: dict[str, list[dict[str, str]]] = {}
for node_id, node_data in workflow.walk_nodes():
ntype = node_data.get("type", "")
node_type_to_nodes.setdefault(ntype, []).append(
NodeInfo(node_id=node_id, type=ntype, title=node_data.get("title", "")).model_dump()
)
result: dict[str, list[dict[str, str]]] = {}
for metric in metrics:
required_node_type = METRIC_NODE_TYPE_MAPPING.get(metric)
if required_node_type is None:
result[metric] = []
continue
result[metric] = node_type_to_nodes.get(required_node_type, [])
return result
@classmethod
def _resolve_workflow(
cls,
target: Union[App, CustomizedSnippet],
target_type: str,
) -> "Workflow | None":
"""Resolve the *published* (preferred) or *draft* workflow for the target."""
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet=target)
if not workflow:
workflow = snippet_service.get_draft_workflow(snippet=target)
return workflow
elif target_type == "app" and isinstance(target, App):
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=target)
if not workflow:
workflow = workflow_service.get_draft_workflow(app_model=target)
return workflow
return None
# ---- Category Resolution ----
@classmethod
def _resolve_evaluation_category(cls, default_metrics: list[DefaultMetric]) -> EvaluationCategory:
"""Derive evaluation category from default_metrics node_info types.
Uses the type of the first node_info found in default_metrics.
Falls back to LLM if no metrics are provided.
"""
for metric in default_metrics:
for node_info in metric.node_info_list:
try:
return EvaluationCategory(node_info.type)
except ValueError:
continue
return EvaluationCategory.LLM
@classmethod
def execute_targets(
cls,
tenant_id: str,
target_type: str,
target_id: str,
input_list: list[EvaluationDatasetInput],
max_workers: int = 5,
) -> tuple[list[dict[str, NodeRunResult]], list[str | None]]:
"""Execute the evaluation target for every test-data item in parallel.
:param tenant_id: Workspace / tenant ID.
:param target_type: ``"app"`` or ``"snippet"``.
:param target_id: ID of the App or CustomizedSnippet.
:param input_list: All test-data items parsed from the dataset.
:param max_workers: Maximum number of parallel worker threads.
:return: Tuple of (node_results, workflow_run_ids).
node_results: ordered list of ``{node_id: NodeRunResult}`` mappings;
the *i*-th element corresponds to ``input_list[i]``.
workflow_run_ids: ordered list of workflow_run_id strings (or None)
for each input item.
"""
from concurrent.futures import ThreadPoolExecutor
from flask import Flask, current_app
flask_app: Flask = current_app._get_current_object() # type: ignore
def _worker(item: EvaluationDatasetInput) -> tuple[dict[str, NodeRunResult], str | None]:
with flask_app.app_context():
from models.engine import db
with Session(db.engine, expire_on_commit=False) as thread_session:
try:
response = cls._run_single_target(
session=thread_session,
target_type=target_type,
target_id=target_id,
item=item,
)
workflow_run_id = cls._extract_workflow_run_id(response)
if not workflow_run_id:
logger.warning(
"No workflow_run_id for item %d (target=%s)",
item.index,
target_id,
)
return {}, None
node_results = cls._query_node_run_results(
session=thread_session,
tenant_id=tenant_id,
app_id=target_id,
workflow_run_id=workflow_run_id,
)
return node_results, workflow_run_id
except Exception:
logger.exception(
"Target execution failed for item %d (target=%s)",
item.index,
target_id,
)
return {}, None
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(_worker, item) for item in input_list]
ordered_results: list[dict[str, NodeRunResult]] = []
ordered_workflow_run_ids: list[str | None] = []
for future in futures:
try:
node_result, wf_run_id = future.result()
ordered_results.append(node_result)
ordered_workflow_run_ids.append(wf_run_id)
except Exception:
logger.exception("Unexpected error collecting target execution result")
ordered_results.append({})
ordered_workflow_run_ids.append(None)
return ordered_results, ordered_workflow_run_ids
@classmethod
def _run_single_target(
cls,
session: Session,
target_type: str,
target_id: str,
item: EvaluationDatasetInput,
) -> Mapping[str, object]:
"""Execute a single evaluation target with one test-data item.
Dispatches to the appropriate execution service based on
``target_type``:
* ``"snippet"`` → :meth:`SnippetGenerateService.run_published`
* ``"app"`` → :meth:`WorkflowAppGenerator().generate` (blocking mode)
:returns: The blocking response mapping from the workflow engine.
:raises ValueError: If the target is not found or not published.
"""
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.evaluation.runners import get_service_account_for_app, get_service_account_for_snippet
if target_type == "snippet":
from services.snippet_generate_service import SnippetGenerateService
snippet = session.query(CustomizedSnippet).filter_by(id=target_id).first()
if not snippet:
raise ValueError(f"Snippet {target_id} not found")
service_account = get_service_account_for_snippet(session, target_id)
return SnippetGenerateService.run_published(
snippet=snippet,
user=service_account,
args={"inputs": item.inputs},
invoke_from=InvokeFrom.SERVICE_API,
)
else:
# target_type == "app"
app = session.query(App).filter_by(id=target_id).first()
if not app:
raise ValueError(f"App {target_id} not found")
service_account = get_service_account_for_app(session, target_id)
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app)
if not workflow:
raise ValueError(f"No published workflow for app {target_id}")
response: Mapping[str, object] = WorkflowAppGenerator().generate(
app_model=app,
workflow=workflow,
user=service_account,
args={"inputs": item.inputs},
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
call_depth=0,
)
return response
@staticmethod
def _extract_workflow_run_id(response: Mapping[str, object]) -> str | None:
"""Extract ``workflow_run_id`` from a blocking workflow response."""
wf_run_id = response.get("workflow_run_id")
if wf_run_id:
return str(wf_run_id)
data = response.get("data")
if isinstance(data, Mapping) and data.get("id"):
return str(data["id"])
return None
@staticmethod
def _query_node_run_results(
session: Session,
tenant_id: str,
app_id: str,
workflow_run_id: str,
) -> dict[str, NodeRunResult]:
"""Query all node execution records for a workflow run."""
from sqlalchemy import asc, select
from graphon.enums import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionModel
stmt = (
WorkflowNodeExecutionModel.preload_offload_data(select(WorkflowNodeExecutionModel))
.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_id,
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
)
.order_by(asc(WorkflowNodeExecutionModel.created_at))
)
node_models: list[WorkflowNodeExecutionModel] = list(session.execute(stmt).scalars().all())
result: dict[str, NodeRunResult] = {}
for node in node_models:
# Convert string-keyed metadata to WorkflowNodeExecutionMetadataKey-keyed
raw_metadata = node.execution_metadata_dict
typed_metadata: dict[WorkflowNodeExecutionMetadataKey, object] = {}
for key, val in raw_metadata.items():
try:
typed_metadata[WorkflowNodeExecutionMetadataKey(key)] = val
except ValueError:
pass # skip unknown metadata keys
result[node.node_id] = NodeRunResult(
status=WorkflowNodeExecutionStatus(node.status),
inputs=node.inputs_dict or {},
process_data=node.process_data_dict or {},
outputs=node.outputs_dict or {},
metadata=typed_metadata,
error=node.error or "",
)
return result
# ---- Dataset Parsing ----
@classmethod
def _parse_dataset(cls, xlsx_content: bytes) -> list[EvaluationDatasetInput]:
"""Parse evaluation dataset from XLSX bytes."""
wb = load_workbook(io.BytesIO(xlsx_content), read_only=True)
ws = wb.active
if ws is None:
raise EvaluationDatasetInvalidError("XLSX file has no active worksheet.")
rows = list(ws.iter_rows(values_only=True))
if len(rows) < 2:
raise EvaluationDatasetInvalidError("Dataset must have at least a header row and one data row.")
headers = [str(h).strip() if h is not None else "" for h in rows[0]]
if not headers or headers[0].lower() != "index":
raise EvaluationDatasetInvalidError("First column header must be 'index'.")
input_headers = headers[1:] # Skip 'index'
items = []
for row_idx, row in enumerate(rows[1:], start=1):
values = list(row)
if all(v is None or str(v).strip() == "" for v in values):
continue # Skip empty rows
index_val = values[0] if values else row_idx
try:
index = int(str(index_val))
except (TypeError, ValueError):
index = row_idx
inputs: dict[str, Any] = {}
for col_idx, header in enumerate(input_headers):
val = values[col_idx + 1] if col_idx + 1 < len(values) else None
inputs[header] = str(val) if val is not None else ""
# Extract expected_output column into dedicated field
expected_output = inputs.pop("expected_output", None)
items.append(
EvaluationDatasetInput(
index=index,
inputs=inputs,
expected_output=expected_output,
)
)
wb.close()
return items
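# A minimal sketch (not part of the service; column names beyond "index" and
# "expected_output" are hypothetical) of the dataset layout that
# _parse_dataset expects: an "index" header first, arbitrary input columns
# next, and an optional "expected_output" column that is lifted into a
# dedicated field.
def _build_example_dataset_xlsx() -> bytes:
    import io

    from openpyxl import Workbook

    wb = Workbook()
    ws = wb.active
    ws.append(["index", "query", "expected_output"])  # header row
    ws.append([1, "What is a snippet?", "A reusable workflow fragment"])
    ws.append([2, "How are snippets run?", "Through the workflow generator"])
    buf = io.BytesIO()
    wb.save(buf)
    return buf.getvalue()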
@classmethod
def execute_retrieval_test_targets(
cls,
dataset_id: str,
account_id: str,
input_list: list[EvaluationDatasetInput],
max_workers: int = 5,
) -> list[NodeRunResult]:
"""Run hit testing against a knowledge base for every input item in parallel.
Each item must supply a ``query`` key in its ``inputs`` dict. The
retrieved segments are normalised into the same ``NodeRunResult`` format
that :class:`RetrievalEvaluationRunner` expects:
.. code-block:: python
NodeRunResult(
inputs={"query": "..."},
outputs={"result": [{"content": "...", "score": ...}, ...]},
)
:returns: Ordered list of ``NodeRunResult`` — one per input item.
If retrieval fails for an item the result has an empty ``result``
list so the runner can still persist a (metric-less) row.
"""
from concurrent.futures import ThreadPoolExecutor
from flask import current_app
flask_app = current_app._get_current_object() # type: ignore
def _worker(item: EvaluationDatasetInput) -> NodeRunResult:
with flask_app.app_context():
from extensions.ext_database import db as flask_db
from models.account import Account
from models.dataset import Dataset
from services.hit_testing_service import HitTestingService
dataset = flask_db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError(f"Dataset {dataset_id} not found")
account = flask_db.session.query(Account).filter_by(id=account_id).first()
if not account:
raise ValueError(f"Account {account_id} not found")
query = str(item.inputs.get("query", ""))
response = HitTestingService.retrieve(
dataset=dataset,
query=query,
account=account,
retrieval_model=None, # Use dataset's configured retrieval model
external_retrieval_model={},
limit=10,
)
records = response.get("records", [])
result_list = [
{
"content": r.get("segment", {}).get("content", "") or r.get("content", ""),
"score": r.get("score"),
}
for r in records
if r.get("segment", {}).get("content") or r.get("content")
]
return NodeRunResult(
inputs={"query": query},
outputs={"result": result_list},
)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(_worker, item) for item in input_list]
results: list[NodeRunResult] = []
for item, future in zip(input_list, futures):
try:
results.append(future.result())
except Exception:
logger.exception("Retrieval test failed for item %d (dataset=%s)", item.index, dataset_id)
results.append(NodeRunResult(inputs={}, outputs={"result": []}))
return results
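# A small sketch (document contents and scores are made up) of the
# NodeRunResult each worker above produces, matching the shape described in
# the docstring so the retrieval evaluation runner can consume hit-testing
# output like a knowledge-retrieval node's output.
def _example_retrieval_node_result() -> NodeRunResult:
    return NodeRunResult(
        inputs={"query": "What is a snippet?"},
        outputs={
            "result": [
                {"content": "A snippet is a reusable workflow fragment.", "score": 0.87},
                {"content": "Snippets are executed via the workflow generator.", "score": 0.42},
            ]
        },
    )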

View File

@@ -220,7 +220,7 @@ class EmailDeliveryTestHandler:
stmt = stmt.where(Account.id.in_(unique_ids))
with self._session_factory() as session:
rows = session.execute(stmt).all()
rows = session.execute(stmt).tuples().all()
return dict(rows)
@staticmethod

View File

@@ -46,6 +46,7 @@ from core.workflow.system_variables import (
)
from core.workflow.variable_pool_initializer import add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
from extensions.ext_database import db
from graphon.entities.workflow_node_execution import (
WorkflowNodeExecution,
@@ -577,6 +578,13 @@ class RagPipelineService:
outputs=workflow_node_execution.outputs,
)
session.commit()
if workflow_node_execution_db_model is not None:
enqueue_draft_node_execution_trace(
execution=workflow_node_execution_db_model,
outputs=workflow_node_execution.outputs,
workflow_execution_id=None,
user_id=account.id,
)
return workflow_node_execution_db_model
def run_datasource_workflow_node(
@@ -1339,6 +1347,12 @@ class RagPipelineService:
outputs=workflow_node_execution.outputs,
)
session.commit()
enqueue_draft_node_execution_trace(
execution=workflow_node_execution_db_model,
outputs=workflow_node_execution.outputs,
workflow_execution_id=None,
user_id=current_user.id,
)
return workflow_node_execution_db_model
def get_recommended_plugins(self, type: str) -> dict:

View File

@@ -1,570 +0,0 @@
import json
import logging
import uuid
from collections.abc import Mapping
from datetime import UTC, datetime
from enum import StrEnum
from urllib.parse import urlparse
import yaml # type: ignore
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.helper import ssrf_proxy
from core.plugin.entities.plugin import PluginDependency
from graphon.enums import BuiltinNodeTypes
from graphon.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_redis import redis_client
from factories import variable_factory
from models import Account
from models.snippet import CustomizedSnippet, SnippetType
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
IMPORT_INFO_REDIS_KEY_PREFIX = "snippet_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "snippet_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.1.0"
# List of node types that are not allowed in snippets
FORBIDDEN_NODE_TYPES = [
BuiltinNodeTypes.START,
BuiltinNodeTypes.HUMAN_INPUT,
]
class ImportMode(StrEnum):
YAML_CONTENT = "yaml-content"
YAML_URL = "yaml-url"
class ImportStatus(StrEnum):
COMPLETED = "completed"
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
PENDING = "pending"
FAILED = "failed"
class SnippetImportInfo(BaseModel):
id: str
status: ImportStatus
snippet_id: str | None = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
error: str = ""
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
def _check_version_compatibility(imported_version: str) -> ImportStatus:
"""Determine import status based on version comparison"""
try:
current_ver = version.parse(CURRENT_DSL_VERSION)
imported_ver = version.parse(imported_version)
except version.InvalidVersion:
return ImportStatus.FAILED
# If imported version is newer than current, always return PENDING
if imported_ver > current_ver:
return ImportStatus.PENDING
# If imported version is older than current's major, return PENDING
if imported_ver.major < current_ver.major:
return ImportStatus.PENDING
# If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS
if imported_ver.minor < current_ver.minor:
return ImportStatus.COMPLETED_WITH_WARNINGS
# Same major and minor; only the micro version may be equal or older, return COMPLETED
return ImportStatus.COMPLETED
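# A quick sketch (assuming CURRENT_DSL_VERSION stays at "0.1.0") of how the
# check above maps imported DSL versions to import statuses.
def _example_version_compatibility() -> None:
    assert _check_version_compatibility("0.1.0") == ImportStatus.COMPLETED
    assert _check_version_compatibility("0.0.9") == ImportStatus.COMPLETED_WITH_WARNINGS  # older minor
    assert _check_version_compatibility("0.2.0") == ImportStatus.PENDING  # newer than current
    assert _check_version_compatibility("not-a-version") == ImportStatus.FAILED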
class SnippetPendingData(BaseModel):
import_mode: str
yaml_content: str
snippet_id: str | None
class CheckDependenciesPendingData(BaseModel):
dependencies: list[PluginDependency]
snippet_id: str | None
class SnippetDslService:
def __init__(self, session: Session):
self._session = session
def import_snippet(
self,
*,
account: Account,
import_mode: str,
yaml_content: str | None = None,
yaml_url: str | None = None,
snippet_id: str | None = None,
name: str | None = None,
description: str | None = None,
) -> SnippetImportInfo:
"""Import a snippet from YAML content or URL."""
import_id = str(uuid.uuid4())
# Validate import mode
try:
mode = ImportMode(import_mode)
except ValueError:
raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content
content: str = ""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_url is required when import_mode is yaml-url",
)
try:
parsed_url = urlparse(yaml_url)
if parsed_url.scheme not in ["http", "https"]:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid URL scheme, only http and https are allowed",
)
response = ssrf_proxy.get(yaml_url, timeout=(10, 30))
if response.status_code != 200:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Failed to fetch YAML from URL: {response.status_code}",
)
content = response.text
if len(content) > DSL_MAX_SIZE:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"YAML content size exceeds maximum limit of {DSL_MAX_SIZE} bytes",
)
except Exception as e:
logger.exception("Failed to fetch YAML from URL")
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Failed to fetch YAML from URL: {str(e)}",
)
elif mode == ImportMode.YAML_CONTENT:
if not yaml_content:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_content is required when import_mode is yaml-content",
)
content = yaml_content
if len(content) > DSL_MAX_SIZE:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"YAML content size exceeds maximum limit of {DSL_MAX_SIZE} bytes",
)
try:
# Parse YAML
data = yaml.safe_load(content)
if not isinstance(data, dict):
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid YAML format: expected a dictionary",
)
# Validate and fix DSL version
if not data.get("version"):
data["version"] = "0.1.0"
# Strictly validate kind field
kind = data.get("kind")
if not kind:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Missing 'kind' field in DSL. Expected 'kind: snippet'.",
)
if kind != "snippet":
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Invalid DSL kind: expected 'snippet', got '{kind}'. This DSL is for {kind}, not snippet.",
)
imported_version = data.get("version", "0.1.0")
if not isinstance(imported_version, str):
raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}")
status = _check_version_compatibility(imported_version)
# Extract snippet data
snippet_data = data.get("snippet")
if not snippet_data:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Missing snippet data in YAML content",
)
# Validate workflow nodes - check for forbidden node types
workflow_data = data.get("workflow", {})
if workflow_data:
graph = workflow_data.get("graph", {})
nodes = graph.get("nodes", [])
forbidden_nodes_found = []
for node in nodes:
node_data = node.get("data", {})
if not node_data:
continue
node_type = node_data.get("type", "")
if node_type in FORBIDDEN_NODE_TYPES:
forbidden_nodes_found.append(node_type)
if forbidden_nodes_found:
forbidden_types_str = ", ".join(set(forbidden_nodes_found))
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Snippet cannot contain the following node types: {forbidden_types_str}",
)
# If snippet_id is provided, check if it exists
snippet = None
if snippet_id:
stmt = select(CustomizedSnippet).where(
CustomizedSnippet.id == snippet_id,
CustomizedSnippet.tenant_id == account.current_tenant_id,
)
snippet = self._session.scalar(stmt)
if not snippet:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Snippet not found",
)
# If major version mismatch, store import info in Redis
if status == ImportStatus.PENDING:
pending_data = SnippetPendingData(
import_mode=import_mode,
yaml_content=content,
snippet_id=snippet_id,
)
redis_client.setex(
f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}",
IMPORT_INFO_REDIS_EXPIRY,
pending_data.model_dump_json(),
)
return SnippetImportInfo(
id=import_id,
status=status,
snippet_id=snippet_id,
imported_dsl_version=imported_version,
)
# Extract dependencies
dependencies = data.get("dependencies", [])
check_dependencies_pending_data = None
if dependencies:
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
# Create or update snippet
snippet = self._create_or_update_snippet(
snippet=snippet,
data=data,
account=account,
name=name,
description=description,
dependencies=check_dependencies_pending_data,
)
return SnippetImportInfo(
id=import_id,
status=status,
snippet_id=snippet.id,
imported_dsl_version=imported_version,
)
except yaml.YAMLError as e:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Invalid YAML format: {str(e)}",
)
except Exception as e:
logger.exception("Failed to import snippet")
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
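# A minimal DSL document (sketch; node ids and field values are hypothetical)
# that the importer above accepts: 'kind' must be 'snippet', 'version' falls
# back to 0.1.0 when missing, and graphs containing start or human_input
# nodes are rejected.
_EXAMPLE_SNIPPET_DSL = """\
version: "0.1.0"
kind: snippet
snippet:
  name: Summarize text
  description: Summarizes an input string
  type: node
  input_fields:
    - variable: text
      label: Text
      type: text-input
      required: true
workflow:
  graph:
    nodes:
      - id: llm-1
        data:
          type: llm
          title: LLM
    edges: []
"""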
def confirm_import(self, *, import_id: str, account: Account) -> SnippetImportInfo:
"""
Confirm an import that requires confirmation
"""
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
pending_data = redis_client.get(redis_key)
if not pending_data:
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Import information expired or does not exist",
)
try:
if not isinstance(pending_data, str | bytes):
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid import information",
)
pending_data_str = pending_data.decode("utf-8") if isinstance(pending_data, bytes) else pending_data
pending = SnippetPendingData.model_validate_json(pending_data_str)
# Re-import with the pending data
return self.import_snippet(
account=account,
import_mode=pending.import_mode,
yaml_content=pending.yaml_content,
snippet_id=pending.snippet_id,
)
except Exception as e:
logger.exception("Failed to confirm import")
return SnippetImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
def check_dependencies(self, snippet: CustomizedSnippet) -> CheckDependenciesResult:
"""
Check dependencies for a snippet
"""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
return CheckDependenciesResult(leaked_dependencies=[])
dependencies = self._extract_dependencies_from_workflow(workflow)
leaked_dependencies = DependenciesAnalysisService.generate_dependencies(
tenant_id=snippet.tenant_id, dependencies=dependencies
)
return CheckDependenciesResult(leaked_dependencies=leaked_dependencies)
def _create_or_update_snippet(
self,
*,
snippet: CustomizedSnippet | None,
data: dict,
account: Account,
name: str | None = None,
description: str | None = None,
dependencies: list[PluginDependency] | None = None,
) -> CustomizedSnippet:
"""
Create or update snippet from DSL data
"""
snippet_data = data.get("snippet", {})
workflow_data = data.get("workflow", {})
# Extract snippet info
snippet_name = name or snippet_data.get("name") or "Untitled Snippet"
snippet_description = description or snippet_data.get("description") or ""
snippet_type_str = snippet_data.get("type", "node")
try:
snippet_type = SnippetType(snippet_type_str)
except ValueError:
snippet_type = SnippetType.NODE
icon_info = snippet_data.get("icon_info", {})
input_fields = snippet_data.get("input_fields", [])
# Create or update snippet
if snippet:
# Update existing snippet
snippet.name = snippet_name
snippet.description = snippet_description
snippet.type = snippet_type.value
snippet.icon_info = icon_info or None
snippet.input_fields = json.dumps(input_fields) if input_fields else None
snippet.updated_by = account.id
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
else:
# Create new snippet
snippet = CustomizedSnippet(
tenant_id=account.current_tenant_id,
name=snippet_name,
description=snippet_description,
type=snippet_type.value,
icon_info=icon_info or None,
input_fields=json.dumps(input_fields) if input_fields else None,
created_by=account.id,
)
self._session.add(snippet)
self._session.flush()
# Create or update draft workflow
if workflow_data:
graph = workflow_data.get("graph", {})
environment_variables_list = workflow_data.get("environment_variables", [])
conversation_variables_list = workflow_data.get("conversation_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
snippet_service = SnippetService()
# Get existing workflow hash if exists
existing_workflow = snippet_service.get_draft_workflow(snippet=snippet)
unique_hash = existing_workflow.unique_hash if existing_workflow else None
snippet_service.sync_draft_workflow(
snippet=snippet,
graph=graph,
unique_hash=unique_hash,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
input_variables=input_fields,
)
self._session.commit()
return snippet
def export_snippet_dsl(self, snippet: CustomizedSnippet, include_secret: bool = False) -> str:
"""
Export snippet as DSL
:param snippet: CustomizedSnippet instance
:param include_secret: Whether include secret variable
:return: YAML string
"""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")
icon_info = snippet.icon_info or {}
export_data = {
"version": CURRENT_DSL_VERSION,
"kind": "snippet",
"snippet": {
"name": snippet.name,
"description": snippet.description or "",
"type": snippet.type,
"icon_info": icon_info,
"input_fields": snippet.input_fields_list,
},
}
self._append_workflow_export_data(
export_data=export_data, snippet=snippet, workflow=workflow, include_secret=include_secret
)
return yaml.dump(export_data, allow_unicode=True) # type: ignore
def _append_workflow_export_data(
self, *, export_data: dict, snippet: CustomizedSnippet, workflow: Workflow, include_secret: bool
) -> None:
"""
Append workflow export data
"""
workflow_dict = workflow.to_dict(include_secret=include_secret)
# Filter workspace related data from nodes
for node in workflow_dict.get("graph", {}).get("nodes", []):
node_data = node.get("data", {})
if not node_data:
continue
data_type = node_data.get("type", "")
if data_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
dataset_ids = node_data.get("dataset_ids", [])
node["data"]["dataset_ids"] = [
self._encrypt_dataset_id(dataset_id=dataset_id, tenant_id=snippet.tenant_id)
for dataset_id in dataset_ids
]
# filter credential id from tool node
if not include_secret and data_type == BuiltinNodeTypes.TOOL:
node_data.pop("credential_id", None)
# filter credential id from agent node
if not include_secret and data_type == BuiltinNodeTypes.AGENT:
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
tool.pop("credential_id", None)
export_data["workflow"] = workflow_dict
dependencies = self._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
tenant_id=snippet.tenant_id, dependencies=dependencies
)
]
def _encrypt_dataset_id(self, *, dataset_id: str, tenant_id: str) -> str:
"""
Encrypt dataset ID for export
"""
# For now, just return the dataset_id as-is
# In the future, we might want to encrypt it
return dataset_id
def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]:
"""
Extract dependencies from workflow
:param workflow: Workflow instance
:return: dependencies list, e.g. ["langgenius/google"]
"""
graph = workflow.graph_dict
dependencies = self._extract_dependencies_from_workflow_graph(graph)
return dependencies
def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]:
"""
Extract dependencies from workflow graph
:param graph: Workflow graph
:return: dependencies list, e.g. ["langgenius/google"]
"""
dependencies = []
for node in graph.get("nodes", []):
node_data = node.get("data", {})
if not node_data:
continue
data_type = node_data.get("type", "")
if data_type == BuiltinNodeTypes.TOOL:
tool_config = node_data.get("tool_configurations", {})
provider_type = tool_config.get("provider_type")
provider_name = tool_config.get("provider")
if provider_type and provider_name:
dependencies.append(f"{provider_name}/{provider_name}")
elif data_type == BuiltinNodeTypes.AGENT:
agent_parameters = node_data.get("agent_parameters", {})
tools = agent_parameters.get("tools", {}).get("value", [])
for tool in tools:
provider_type = tool.get("provider_type")
provider_name = tool.get("provider")
if provider_type and provider_name:
dependencies.append(f"{provider_name}/{provider_name}")
return dependencies

View File

@@ -1,421 +0,0 @@
"""
Service for generating snippet workflow executions.
Uses an adapter pattern to bridge CustomizedSnippet with the App-based
WorkflowAppGenerator. The adapter (_SnippetAsApp) provides the minimal App-like
interface needed by the generator, avoiding modifications to core workflow
infrastructure.
Key invariants:
- Snippets always run as WORKFLOW mode (not CHAT or ADVANCED_CHAT).
- The adapter maps snippet.id to app_id in workflow execution records.
- Snippet debugging has no rate limiting (max_active_requests = 0).
Supported execution modes:
- Full workflow run (generate): Runs the entire draft workflow as SSE stream.
- Single node run (run_draft_node): Synchronous single-step debugging for regular nodes.
- Single iteration run (generate_single_iteration): SSE stream for iteration container nodes.
- Single loop run (generate_single_loop): SSE stream for loop container nodes.
"""
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Union
from sqlalchemy.orm import make_transient
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from graphon.file.models import File
from factories import file_factory
from models import Account
from models.model import AppMode, EndUser
from models.snippet import CustomizedSnippet
from models.workflow import Workflow, WorkflowNodeExecutionModel
from services.snippet_service import SnippetService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
class _SnippetAsApp:
"""
Minimal adapter that wraps a CustomizedSnippet to satisfy the App-like
interface required by WorkflowAppGenerator, WorkflowAppConfigManager,
and WorkflowService.run_draft_workflow_node.
Used properties:
- id: maps to snippet.id (stored as app_id in workflows table)
- tenant_id: maps to snippet.tenant_id
- mode: hardcoded to AppMode.WORKFLOW since snippets always run as workflows
- max_active_requests: defaults to 0 (no limit) for snippet debugging
- app_model_config_id: None (snippets don't have app model configs)
"""
id: str
tenant_id: str
mode: str
max_active_requests: int
app_model_config_id: str | None
def __init__(self, snippet: CustomizedSnippet) -> None:
self.id = snippet.id
self.tenant_id = snippet.tenant_id
self.mode = AppMode.WORKFLOW.value
self.max_active_requests = 0
self.app_model_config_id = None
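# A small sketch (hypothetical snippet instance) of the adapter idea above:
# any object exposing id, tenant_id, mode, max_active_requests and
# app_model_config_id can stand in for an App when handed to
# WorkflowAppGenerator.
def _example_adapter(snippet: CustomizedSnippet) -> _SnippetAsApp:
    proxy = _SnippetAsApp(snippet)
    assert proxy.mode == AppMode.WORKFLOW.value
    assert proxy.max_active_requests == 0  # no rate limiting for snippet debugging
    return proxy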
class SnippetGenerateService:
"""
Service for running snippet workflow executions.
Adapts CustomizedSnippet to work with the existing App-based
WorkflowAppGenerator infrastructure, avoiding duplication of the
complex workflow execution pipeline.
"""
# Specific ID for the injected virtual Start node so it can be recognised
_VIRTUAL_START_NODE_ID = "__snippet_virtual_start__"
@classmethod
def generate(
cls,
snippet: CustomizedSnippet,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
"""
Run a snippet's draft workflow.
Retrieves the draft workflow, adapts the snippet to an App-like proxy,
then delegates execution to WorkflowAppGenerator.
If the workflow graph has no Start node, a virtual Start node is injected
in-memory so that:
1. Graph validation passes (root node must have execution_type=ROOT).
2. User inputs are processed into the variable pool by the StartNode logic.
:param snippet: CustomizedSnippet instance
:param user: Account or EndUser initiating the run
:param args: Workflow inputs (must include "inputs" key)
:param invoke_from: Source of invocation (typically DEBUGGER)
:param streaming: Whether to stream the response
:return: Blocking response mapping or SSE streaming generator
:raises ValueError: If the snippet has no draft workflow
"""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise ValueError("Workflow not initialized")
# Inject a virtual Start node when the graph doesn't have one.
workflow = cls._ensure_start_node(workflow, snippet)
# Adapt snippet to App-like interface for WorkflowAppGenerator
app_proxy = _SnippetAsApp(snippet)
return WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().generate(
app_model=app_proxy, # type: ignore[arg-type]
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
call_depth=0,
)
)
@classmethod
def run_published(
cls,
snippet: CustomizedSnippet,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
) -> Mapping[str, Any]:
"""
Run a snippet's published workflow in non-streaming (blocking) mode.
Similar to :meth:`generate` but targets the published workflow instead
of the draft, and returns the raw blocking response without SSE
wrapping. Designed for programmatic callers such as evaluation runners.
:param snippet: CustomizedSnippet instance (must be published)
:param user: Account or EndUser initiating the run
:param args: Workflow inputs (must include "inputs" key)
:param invoke_from: Source of invocation
:return: Blocking response mapping with workflow outputs
:raises ValueError: If the snippet has no published workflow
"""
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet)
if not workflow:
raise ValueError("No published workflow found for snippet")
# Inject a virtual Start node when the graph doesn't have one.
workflow = cls._ensure_start_node(workflow, snippet)
app_proxy = _SnippetAsApp(snippet)
response: Mapping[str, Any] = WorkflowAppGenerator().generate(
app_model=app_proxy, # type: ignore[arg-type]
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=False,
)
return response
@classmethod
def ensure_start_node_for_worker(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow:
"""Public wrapper for worker-thread start-node injection."""
return cls._ensure_start_node(workflow, snippet)
@classmethod
def _ensure_start_node(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow:
"""
Return *workflow* with a Start node.
If the graph already contains a Start node, the original workflow is
returned unchanged. Otherwise a virtual Start node is injected and the
workflow object is detached from the SQLAlchemy session so the in-memory
change is never flushed to the database.
"""
graph_dict = workflow.graph_dict
nodes: list[dict[str, Any]] = graph_dict.get("nodes", [])
has_start = any(node.get("data", {}).get("type") == "start" for node in nodes)
if has_start:
return workflow
modified_graph = cls._inject_virtual_start_node(
graph_dict=graph_dict,
input_fields=snippet.input_fields_list,
)
# Detach from session to prevent accidental DB persistence of the
# modified graph. All attributes remain accessible for read.
make_transient(workflow)
workflow.graph = json.dumps(modified_graph)
return workflow
@classmethod
def _inject_virtual_start_node(
cls,
graph_dict: Mapping[str, Any],
input_fields: list[dict[str, Any]],
) -> dict[str, Any]:
"""
Build a new graph dict with a virtual Start node prepended.
The virtual Start node is wired to every existing node that has no
incoming edges (i.e. the current root candidates). This guarantees that
graph validation passes (the root node has execution_type=ROOT) and that
user inputs are processed into the variable pool by the StartNode logic.
:param graph_dict: Original graph configuration.
:param input_fields: Snippet input field definitions from
``CustomizedSnippet.input_fields_list``.
:return: New graph dict containing the virtual Start node and edges.
"""
nodes: list[dict[str, Any]] = list(graph_dict.get("nodes", []))
edges: list[dict[str, Any]] = list(graph_dict.get("edges", []))
# Identify nodes with no incoming edges.
nodes_with_incoming: set[str] = set()
for edge in edges:
target = edge.get("target")
if isinstance(target, str):
nodes_with_incoming.add(target)
root_candidate_ids = [n["id"] for n in nodes if n["id"] not in nodes_with_incoming]
# Build Start node ``variables`` from snippet input fields.
start_variables: list[dict[str, Any]] = []
for field in input_fields:
var: dict[str, Any] = {
"variable": field.get("variable", ""),
"label": field.get("label", field.get("variable", "")),
"type": field.get("type", "text-input"),
"required": field.get("required", False),
"options": field.get("options", []),
}
if field.get("max_length") is not None:
var["max_length"] = field["max_length"]
start_variables.append(var)
virtual_start_node: dict[str, Any] = {
"id": cls._VIRTUAL_START_NODE_ID,
"data": {
"type": "start",
"title": "Start",
"variables": start_variables,
},
}
# Create edges from virtual Start to each root candidate.
new_edges: list[dict[str, Any]] = [
{
"source": cls._VIRTUAL_START_NODE_ID,
"sourceHandle": "source",
"target": root_id,
"targetHandle": "target",
}
for root_id in root_candidate_ids
]
return {
**graph_dict,
"nodes": [virtual_start_node, *nodes],
"edges": [*edges, *new_edges],
}
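# A small sketch (the graph and input field below are hypothetical) driving
# the injection above: a single LLM node with no incoming edges becomes a
# root candidate and is wired to the injected "__snippet_virtual_start__"
# node.
def _example_virtual_start_injection() -> dict:
    graph = {
        "nodes": [{"id": "llm-1", "data": {"type": "llm", "title": "LLM"}}],
        "edges": [],
    }
    fields = [{"variable": "query", "label": "Query", "type": "text-input", "required": True}]
    return SnippetGenerateService._inject_virtual_start_node(graph_dict=graph, input_fields=fields)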
@classmethod
def run_draft_node(
cls,
snippet: CustomizedSnippet,
node_id: str,
user_inputs: Mapping[str, Any],
account: Account,
query: str = "",
files: Sequence[File] | None = None,
) -> WorkflowNodeExecutionModel:
"""
Run a single node in a snippet's draft workflow (single-step debugging).
Retrieves the draft workflow, adapts the snippet to an App-like proxy,
parses file inputs, then delegates to WorkflowService.run_draft_workflow_node.
:param snippet: CustomizedSnippet instance
:param node_id: ID of the node to run
:param user_inputs: User input values for the node
:param account: Account initiating the run
:param query: Optional query string
:param files: Optional parsed file objects
:return: WorkflowNodeExecutionModel with execution results
:raises ValueError: If the snippet has no draft workflow
"""
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise ValueError("Workflow not initialized")
app_proxy = _SnippetAsApp(snippet)
workflow_service = WorkflowService()
return workflow_service.run_draft_workflow_node(
app_model=app_proxy, # type: ignore[arg-type]
draft_workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
account=account,
query=query,
files=files,
)
@classmethod
def generate_single_iteration(
cls,
snippet: CustomizedSnippet,
user: Union[Account, EndUser],
node_id: str,
args: Mapping[str, Any],
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
"""
Run a single iteration node in a snippet's draft workflow.
Iteration nodes are container nodes that execute their sub-graph multiple
times, producing many events. Therefore, this uses the full WorkflowAppGenerator
pipeline with SSE streaming (unlike regular single-step node run).
:param snippet: CustomizedSnippet instance
:param user: Account or EndUser initiating the run
:param node_id: ID of the iteration node to run
:param args: Dict containing 'inputs' key with iteration input data
:param streaming: Whether to stream the response (should be True)
:return: SSE streaming generator
:raises ValueError: If the snippet has no draft workflow
"""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise ValueError("Workflow not initialized")
app_proxy = _SnippetAsApp(snippet)
return WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
app_model=app_proxy, # type: ignore[arg-type]
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
)
@classmethod
def generate_single_loop(
cls,
snippet: CustomizedSnippet,
user: Union[Account, EndUser],
node_id: str,
args: Any,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
"""
Run a single loop node in a snippet's draft workflow.
Loop nodes are container nodes that execute their sub-graph repeatedly,
producing many events. Therefore, this uses the full WorkflowAppGenerator
pipeline with SSE streaming (unlike regular single-step node run).
:param snippet: CustomizedSnippet instance
:param user: Account or EndUser initiating the run
:param node_id: ID of the loop node to run
:param args: Pydantic model with 'inputs' attribute containing loop input data
:param streaming: Whether to stream the response (should be True)
:return: SSE streaming generator
:raises ValueError: If the snippet has no draft workflow
"""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise ValueError("Workflow not initialized")
app_proxy = _SnippetAsApp(snippet)
return WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
app_model=app_proxy, # type: ignore[arg-type]
workflow=workflow,
node_id=node_id,
user=user,
args=args, # type: ignore[arg-type]
streaming=streaming,
)
)
@staticmethod
def parse_files(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]:
"""
Parse file mappings into File objects based on workflow configuration.
:param workflow: Workflow instance for file upload config
:param files: Raw file mapping dicts
:return: Parsed File objects
"""
files = files or []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config is None:
return []
return file_factory.build_from_mappings(
mappings=files,
tenant_id=workflow.tenant_id,
config=file_extra_config,
)

View File

@@ -1,573 +0,0 @@
import json
import logging
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.node_factory import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from graphon.enums import NodeType
from graphon.variables.variables import VariableBase
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.enums import WorkflowRunTriggeredFrom
from models.snippet import CustomizedSnippet, SnippetType
from models.workflow import (
Workflow,
WorkflowNodeExecutionModel,
WorkflowRun,
WorkflowType,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.errors.app import WorkflowHashNotEqualError
logger = logging.getLogger(__name__)
class SnippetService:
"""Service for managing customized snippets."""
def __init__(self, session_maker: sessionmaker | None = None):
"""Initialize SnippetService with repository dependencies."""
if session_maker is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker
)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
# --- CRUD Operations ---
@staticmethod
def get_snippets(
*,
tenant_id: str,
page: int = 1,
limit: int = 20,
keyword: str | None = None,
is_published: bool | None = None,
creators: list[str] | None = None,
) -> tuple[Sequence[CustomizedSnippet], int, bool]:
"""
Get paginated list of snippets with optional search.
:param tenant_id: Tenant ID
:param page: Page number (1-indexed)
:param limit: Number of items per page
:param keyword: Optional search keyword for name/description
:param is_published: Optional filter by published status (True/False/None for all)
:param creators: Optional filter by creator account IDs
:return: Tuple of (snippets list, total count, has_more flag)
"""
stmt = (
select(CustomizedSnippet)
.where(CustomizedSnippet.tenant_id == tenant_id)
.order_by(CustomizedSnippet.created_at.desc())
)
if keyword:
stmt = stmt.where(
CustomizedSnippet.name.ilike(f"%{keyword}%") | CustomizedSnippet.description.ilike(f"%{keyword}%")
)
if is_published is not None:
stmt = stmt.where(CustomizedSnippet.is_published == is_published)
if creators:
stmt = stmt.where(CustomizedSnippet.created_by.in_(creators))
# Get total count
count_stmt = select(func.count()).select_from(stmt.subquery())
total = db.session.scalar(count_stmt) or 0
# Apply pagination
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
snippets = list(db.session.scalars(stmt).all())
has_more = len(snippets) > limit
if has_more:
snippets = snippets[:-1]
return snippets, total, has_more
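# Usage sketch (tenant id and keyword are hypothetical) for the paginated
# listing above; has_more reflects the limit+1 overfetch used in the query.
def _example_list_snippets() -> None:
    snippets, total, has_more = SnippetService.get_snippets(
        tenant_id="tenant-123",
        page=1,
        limit=20,
        keyword="summar",
        is_published=True,
    )
    print(total, has_more, [s.name for s in snippets])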
@staticmethod
def get_snippet_by_id(
*,
snippet_id: str,
tenant_id: str,
) -> CustomizedSnippet | None:
"""
Get snippet by ID with tenant isolation.
:param snippet_id: Snippet ID
:param tenant_id: Tenant ID
:return: CustomizedSnippet or None
"""
return (
db.session.query(CustomizedSnippet)
.where(
CustomizedSnippet.id == snippet_id,
CustomizedSnippet.tenant_id == tenant_id,
)
.first()
)
@staticmethod
def create_snippet(
*,
tenant_id: str,
name: str,
description: str | None,
snippet_type: SnippetType,
icon_info: dict | None,
input_fields: list[dict] | None,
account: Account,
) -> CustomizedSnippet:
"""
Create a new snippet.
:param tenant_id: Tenant ID
:param name: Snippet name (must be unique per tenant)
:param description: Snippet description
:param snippet_type: Type of snippet (node or group)
:param icon_info: Icon information
:param input_fields: Input field definitions
:param account: Creator account
:return: Created CustomizedSnippet
:raises ValueError: If name already exists
"""
# Check if name already exists for this tenant
existing = (
db.session.query(CustomizedSnippet)
.where(
CustomizedSnippet.tenant_id == tenant_id,
CustomizedSnippet.name == name,
)
.first()
)
if existing:
raise ValueError(f"Snippet with name '{name}' already exists")
snippet = CustomizedSnippet(
tenant_id=tenant_id,
name=name,
description=description or "",
type=snippet_type.value,
icon_info=icon_info,
input_fields=json.dumps(input_fields) if input_fields else None,
created_by=account.id,
)
db.session.add(snippet)
db.session.commit()
return snippet
@staticmethod
def update_snippet(
*,
session: Session,
snippet: CustomizedSnippet,
account_id: str,
data: dict,
) -> CustomizedSnippet:
"""
Update snippet attributes.
:param session: Database session
:param snippet: Snippet to update
:param account_id: ID of account making the update
:param data: Dictionary of fields to update
:return: Updated CustomizedSnippet
"""
if "name" in data:
# Check if new name already exists for this tenant
existing = (
session.query(CustomizedSnippet)
.where(
CustomizedSnippet.tenant_id == snippet.tenant_id,
CustomizedSnippet.name == data["name"],
CustomizedSnippet.id != snippet.id,
)
.first()
)
if existing:
raise ValueError(f"Snippet with name '{data['name']}' already exists")
snippet.name = data["name"]
if "description" in data:
snippet.description = data["description"]
if "icon_info" in data:
snippet.icon_info = data["icon_info"]
snippet.updated_by = account_id
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
session.add(snippet)
return snippet
@staticmethod
def delete_snippet(
*,
session: Session,
snippet: CustomizedSnippet,
) -> bool:
"""
Delete a snippet.
:param session: Database session
:param snippet: Snippet to delete
:return: True if deleted successfully
"""
session.delete(snippet)
return True
# --- Workflow Operations ---
def get_draft_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
"""
Get draft workflow for snippet.
:param snippet: CustomizedSnippet instance
:return: Draft Workflow or None
"""
workflow = (
db.session.query(Workflow)
.where(
Workflow.tenant_id == snippet.tenant_id,
Workflow.app_id == snippet.id,
Workflow.type == WorkflowType.SNIPPET.value,
Workflow.version == "draft",
)
.first()
)
return workflow
def get_published_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
"""
Get published workflow for snippet.
:param snippet: CustomizedSnippet instance
:return: Published Workflow or None
"""
if not snippet.workflow_id:
return None
workflow = (
db.session.query(Workflow)
.where(
Workflow.tenant_id == snippet.tenant_id,
Workflow.app_id == snippet.id,
Workflow.type == WorkflowType.SNIPPET.value,
Workflow.id == snippet.workflow_id,
)
.first()
)
return workflow
def sync_draft_workflow(
self,
*,
snippet: CustomizedSnippet,
graph: dict,
unique_hash: str | None,
account: Account,
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
input_variables: list[dict] | None = None,
) -> Workflow:
"""
Sync draft workflow for snippet.
:param snippet: CustomizedSnippet instance
:param graph: Workflow graph configuration
:param unique_hash: Hash for conflict detection
:param account: Account making the change
:param environment_variables: Environment variables
:param conversation_variables: Conversation variables
:param input_variables: Input variables for snippet
:return: Synced Workflow
:raises WorkflowHashNotEqualError: If hash mismatch
"""
workflow = self.get_draft_workflow(snippet=snippet)
if workflow and workflow.unique_hash != unique_hash:
raise WorkflowHashNotEqualError()
# Create draft workflow if not found
if not workflow:
workflow = Workflow(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
features="{}",
type=WorkflowType.SNIPPET.value,
version="draft",
graph=json.dumps(graph),
created_by=account.id,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
db.session.add(workflow)
db.session.flush()
else:
# Update existing draft workflow
workflow.graph = json.dumps(graph)
workflow.updated_by = account.id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
# Update snippet's input_fields if provided
if input_variables is not None:
snippet.input_fields = json.dumps(input_variables)
snippet.updated_by = account.id
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
return workflow
def publish_workflow(
self,
*,
session: Session,
snippet: CustomizedSnippet,
account: Account,
) -> Workflow:
"""
Publish the draft workflow as a new version.
:param session: Database session
:param snippet: CustomizedSnippet instance
:param account: Account making the change
:return: Published Workflow
:raises ValueError: If no draft workflow exists
"""
draft_workflow_stmt = select(Workflow).where(
Workflow.tenant_id == snippet.tenant_id,
Workflow.app_id == snippet.id,
Workflow.type == WorkflowType.SNIPPET.value,
Workflow.version == "draft",
)
draft_workflow = session.scalar(draft_workflow_stmt)
if not draft_workflow:
raise ValueError("No valid workflow found.")
# Create new published workflow
workflow = Workflow.new(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
type=draft_workflow.type,
version=str(datetime.now(UTC).replace(tzinfo=None)),
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables,
marked_name="",
marked_comment="",
)
session.add(workflow)
# Update snippet version
snippet.version += 1
snippet.is_published = True
snippet.workflow_id = workflow.id
snippet.updated_by = account.id
session.add(snippet)
return workflow
def get_all_published_workflows(
self,
*,
session: Session,
snippet: CustomizedSnippet,
page: int,
limit: int,
) -> tuple[Sequence[Workflow], bool]:
"""
Get all published workflow versions for snippet.
:param session: Database session
:param snippet: CustomizedSnippet instance
:param page: Page number
:param limit: Items per page
:return: Tuple of (workflows list, has_more flag)
"""
if not snippet.workflow_id:
return [], False
stmt = (
select(Workflow)
.where(
Workflow.app_id == snippet.id,
Workflow.type == WorkflowType.SNIPPET.value,
Workflow.version != "draft",
)
.order_by(Workflow.version.desc())
.limit(limit + 1)
.offset((page - 1) * limit)
)
workflows = list(session.scalars(stmt).all())
has_more = len(workflows) > limit
if has_more:
workflows = workflows[:-1]
return workflows, has_more
# --- Default Block Configs ---
def get_default_block_configs(self) -> list[dict]:
"""
Get default block configurations for all node types.
:return: List of default configurations
"""
default_block_configs: list[dict[str, Any]] = []
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
node_class = node_class_mapping[LATEST_VERSION]
default_config = node_class.get_default_config()
if default_config:
default_block_configs.append(dict(default_config))
return default_block_configs
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
"""
Get default config for specific node type.
:param node_type: Node type string
:param filters: Optional filters
:return: Default configuration or None
"""
node_type_enum = NodeType(node_type)
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
return None
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
default_config = node_class.get_default_config(filters=filters)
if not default_config:
return None
return default_config
# --- Workflow Run Operations ---
def get_snippet_workflow_runs(
self,
*,
snippet: CustomizedSnippet,
args: dict,
) -> InfiniteScrollPagination:
"""
Get paginated workflow runs for snippet.
:param snippet: CustomizedSnippet instance
:param args: Request arguments (last_id, limit)
:return: InfiniteScrollPagination result
"""
limit = int(args.get("limit", 20))
last_id = args.get("last_id")
triggered_from_values = [
WorkflowRunTriggeredFrom.DEBUGGING,
]
return self._workflow_run_repo.get_paginated_workflow_runs(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
triggered_from=triggered_from_values,
limit=limit,
last_id=last_id,
)
def get_snippet_workflow_run(
self,
*,
snippet: CustomizedSnippet,
run_id: str,
) -> WorkflowRun | None:
"""
Get workflow run details.
:param snippet: CustomizedSnippet instance
:param run_id: Workflow run ID
:return: WorkflowRun or None
"""
return self._workflow_run_repo.get_workflow_run_by_id(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
run_id=run_id,
)
def get_snippet_workflow_run_node_executions(
self,
*,
snippet: CustomizedSnippet,
run_id: str,
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Get workflow run node execution list.
:param snippet: CustomizedSnippet instance
:param run_id: Workflow run ID
:return: List of WorkflowNodeExecutionModel
"""
workflow_run = self.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
if not workflow_run:
return []
node_executions = self._node_execution_service_repo.get_executions_by_workflow_run(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
workflow_run_id=workflow_run.id,
)
return node_executions
# --- Node Execution Operations ---
def get_snippet_node_last_run(
self,
*,
snippet: CustomizedSnippet,
workflow: Workflow,
node_id: str,
) -> WorkflowNodeExecutionModel | None:
"""
Get the most recent execution for a specific node in a snippet workflow.
:param snippet: CustomizedSnippet instance
:param workflow: Workflow instance
:param node_id: Node identifier
:return: WorkflowNodeExecutionModel or None
"""
return self._node_execution_service_repo.get_node_last_execution(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
workflow_id=workflow.id,
node_id=node_id,
)
# --- Use Count ---
@staticmethod
def increment_use_count(
*,
session: Session,
snippet: CustomizedSnippet,
) -> None:
"""
Increment the use_count when snippet is used.
:param session: Database session
:param snippet: CustomizedSnippet instance
"""
snippet.use_count += 1
session.add(snippet)

View File

@@ -24,28 +24,17 @@ class LogView:
"""Lightweight wrapper for WorkflowAppLog with computed details.
- Exposes `details_` for marshalling to `details` in API response
- Exposes `evaluation_` for marshalling evaluation metrics in API response
- Proxies all other attributes to the underlying `WorkflowAppLog`
"""
def __init__(
self,
log: WorkflowAppLog,
details: LogViewDetails | None,
evaluation: list[dict] | None = None,
):
def __init__(self, log: WorkflowAppLog, details: LogViewDetails | None):
self.log = log
self.details_ = details
self.evaluation_ = evaluation
@property
def details(self) -> LogViewDetails | None:
return self.details_
@property
def evaluation(self) -> list[dict] | None:
return self.evaluation_
def __getattr__(self, name):
return getattr(self.log, name)
@@ -182,20 +171,12 @@ class WorkflowAppService:
# Execute query and get items
if detail:
rows = session.execute(offset_stmt).all()
logs_with_details = [
(log, {"trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, meta_val)})
items = [
LogView(log, {"trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, meta_val)})
for log, meta_val in rows
]
else:
logs_with_details = [(log, None) for log in session.scalars(offset_stmt).all()]
workflow_run_ids = [log.workflow_run_id for log, _ in logs_with_details]
eval_map = self._batch_query_evaluation_metrics(session, workflow_run_ids)
items = [
LogView(log, details, evaluation=eval_map.get(log.workflow_run_id))
for log, details in logs_with_details
]
items = [LogView(log, None) for log in session.scalars(offset_stmt).all()]
return {
"page": page,
"limit": limit,
@@ -277,45 +258,6 @@ class WorkflowAppService:
"data": items,
}
@staticmethod
def _batch_query_evaluation_metrics(
session: Session,
workflow_run_ids: list[str],
) -> dict[str, list[dict[str, Any]]]:
"""Return evaluation metrics keyed by workflow_run_id.
Only returns metrics from completed evaluation runs. If a workflow
run was not part of any evaluation (or the evaluation has not
completed), it will be absent from the result dict.
"""
from models.evaluation import EvaluationRun, EvaluationRunItem, EvaluationRunStatus
if not workflow_run_ids:
return {}
non_null_ids = [wid for wid in workflow_run_ids if wid]
if not non_null_ids:
return {}
stmt = (
select(EvaluationRunItem.workflow_run_id, EvaluationRunItem.metrics)
.join(EvaluationRun, EvaluationRun.id == EvaluationRunItem.evaluation_run_id)
.where(
EvaluationRunItem.workflow_run_id.in_(non_null_ids),
EvaluationRun.status == EvaluationRunStatus.COMPLETED,
)
)
rows = session.execute(stmt).all()
result: dict[str, list[dict[str, Any]]] = {}
for wf_run_id, metrics_json in rows:
if wf_run_id and metrics_json:
parsed: list[dict[str, Any]] = json.loads(metrics_json)
existing = result.get(wf_run_id, [])
existing.extend(parsed)
result[wf_run_id] = existing
return result
def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]:
metadata: dict[str, Any] | None = self._safe_json_loads(meta_val)
if not metadata:

View File

@@ -27,6 +27,7 @@ from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_deb
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
from enums.cloud_plan import CloudPlan
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
@@ -849,6 +850,13 @@ class WorkflowService:
draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs)
session.commit()
enqueue_draft_node_execution_trace(
execution=workflow_node_execution,
outputs=outputs,
workflow_execution_id=None,
user_id=account.id,
)
return workflow_node_execution
def get_human_input_form_preview(

View File

@@ -0,0 +1,52 @@
"""Celery worker for enterprise metric/log telemetry events.
This module defines the Celery task that processes telemetry envelopes
from the enterprise_telemetry queue. It deserializes envelopes and
dispatches them to the EnterpriseMetricHandler.
"""
import json
import logging
from celery import shared_task
from enterprise.telemetry.contracts import TelemetryEnvelope
from enterprise.telemetry.metric_handler import EnterpriseMetricHandler
logger = logging.getLogger(__name__)
@shared_task(queue="enterprise_telemetry")
def process_enterprise_telemetry(envelope_json: str) -> None:
"""Process enterprise metric/log telemetry envelope.
This task is enqueued by the TelemetryGateway for metric/log-only
events. It deserializes the envelope and dispatches to the handler.
Best-effort processing: logs errors but never raises, to avoid
failing user requests due to telemetry issues.
Args:
envelope_json: JSON-serialized TelemetryEnvelope.
"""
try:
# Deserialize envelope
envelope_dict = json.loads(envelope_json)
envelope = TelemetryEnvelope.model_validate(envelope_dict)
# Process through handler
handler = EnterpriseMetricHandler()
handler.handle(envelope)
logger.debug(
"Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s",
envelope.tenant_id,
envelope.event_id,
envelope.case,
)
except Exception:
# Best-effort: log and drop on error, never fail user request
logger.warning(
"Failed to process enterprise telemetry envelope, dropping event",
exc_info=True,
)
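# A minimal sketch of how a producer could hand an envelope to this task: the
# envelope is serialized to JSON to match the json.loads/model_validate
# deserialization above (the producer side is not part of this diff).
def _example_enqueue_telemetry(envelope: TelemetryEnvelope) -> None:
    process_enterprise_telemetry.delay(envelope.model_dump_json())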

View File

@@ -1,454 +0,0 @@
import io
import json
import logging
from typing import Any
from celery import shared_task
from openpyxl import Workbook
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
from openpyxl.utils import get_column_letter
from configs import dify_config
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
EvaluationCategory,
EvaluationDatasetInput,
EvaluationItemResult,
EvaluationRunData,
)
from core.evaluation.entities.judgment_entity import JudgmentConfig
from core.evaluation.evaluation_manager import EvaluationManager
from core.evaluation.runners.agent_evaluation_runner import AgentEvaluationRunner
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from core.evaluation.runners.llm_evaluation_runner import LLMEvaluationRunner
from core.evaluation.runners.retrieval_evaluation_runner import RetrievalEvaluationRunner
from core.evaluation.runners.snippet_evaluation_runner import SnippetEvaluationRunner
from core.evaluation.runners.workflow_evaluation_runner import WorkflowEvaluationRunner
from graphon.node_events import NodeRunResult
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.evaluation import EvaluationRun, EvaluationRunStatus
from models.model import UploadFile
from services.evaluation_service import EvaluationService
logger = logging.getLogger(__name__)
@shared_task(queue="evaluation")
def run_evaluation(run_data_dict: dict[str, Any]) -> None:
"""Celery task for running evaluations asynchronously.
Workflow:
1. Deserialize EvaluationRunData
2. Update status to RUNNING
3. Select appropriate Runner based on evaluation_category
4. Execute runner.run() which handles target execution + metric computation
5. Generate result XLSX
6. Update EvaluationRun status to COMPLETED
"""
run_data = EvaluationRunData.model_validate(run_data_dict)
with db.engine.connect() as connection:
from sqlalchemy.orm import Session
session = Session(bind=connection)
try:
_execute_evaluation(session, run_data)
except Exception as e:
logger.exception("Evaluation run %s failed", run_data.evaluation_run_id)
_mark_run_failed(session, run_data.evaluation_run_id, str(e))
finally:
session.close()
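# Dispatch sketch (the producer side is not part of this file): the caller
# serializes EvaluationRunData to a plain dict before enqueueing, matching
# the model_validate call at the top of the task.
def _example_dispatch_evaluation(run_data: EvaluationRunData) -> None:
    run_evaluation.delay(run_data.model_dump(mode="json"))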
def _execute_evaluation(session: Any, run_data: EvaluationRunData) -> None:
"""Core evaluation execution logic."""
evaluation_run = session.query(EvaluationRun).filter_by(id=run_data.evaluation_run_id).first()
if not evaluation_run:
logger.error("EvaluationRun %s not found", run_data.evaluation_run_id)
return
# Check if cancelled
if evaluation_run.status == EvaluationRunStatus.CANCELLED:
logger.info("EvaluationRun %s was cancelled", run_data.evaluation_run_id)
return
# Get evaluation instance
evaluation_instance = EvaluationManager.get_evaluation_instance()
if evaluation_instance is None:
raise ValueError("Evaluation framework not configured")
if run_data.target_type == "dataset":
results: list[EvaluationItemResult] = _execute_retrieval_test(
session=session,
evaluation_run=evaluation_run,
run_data=run_data,
evaluation_instance=evaluation_instance,
)
else:
evaluation_service = EvaluationService()
node_run_result_mapping_list, workflow_run_ids = evaluation_service.execute_targets(
tenant_id=run_data.tenant_id,
target_type=run_data.target_type,
target_id=run_data.target_id,
input_list=run_data.input_list,
)
results = _execute_evaluation_runner(
session=session,
run_data=run_data,
evaluation_instance=evaluation_instance,
node_run_result_mapping_list=node_run_result_mapping_list,
)
_backfill_workflow_run_ids(
session=session,
evaluation_run_id=run_data.evaluation_run_id,
input_list=run_data.input_list,
workflow_run_ids=workflow_run_ids,
)
# Compute summary metrics
metrics_summary = _compute_metrics_summary(results, run_data.judgment_config)
# Generate result XLSX
result_xlsx = _generate_result_xlsx(run_data.input_list, results)
# Store result file
result_file_id = _store_result_file(run_data.tenant_id, run_data.evaluation_run_id, result_xlsx, session)
# Update run to completed
evaluation_run: EvaluationRun = session.query(EvaluationRun).filter_by(id=run_data.evaluation_run_id).first()
if evaluation_run:
evaluation_run.status = EvaluationRunStatus.COMPLETED
evaluation_run.completed_at = naive_utc_now()
evaluation_run.metrics_summary = json.dumps(metrics_summary)
if result_file_id:
evaluation_run.result_file_id = result_file_id
session.commit()
logger.info("Evaluation run %s completed successfully", run_data.evaluation_run_id)
def _execute_evaluation_runner(
session: Any,
run_data: EvaluationRunData,
evaluation_instance: BaseEvaluationInstance,
node_run_result_mapping_list: list[dict[str, NodeRunResult]],
) -> list[EvaluationItemResult]:
"""Execute the evaluation runner."""
default_metrics = run_data.default_metrics
customized_metrics = run_data.customized_metrics
results: list[EvaluationItemResult] = []
for default_metric in default_metrics:
for node_info in default_metric.node_info_list:
node_run_result_list: list[NodeRunResult] = []
for node_run_result_mapping in node_run_result_mapping_list:
node_run_result = node_run_result_mapping.get(node_info.node_id)
if node_run_result is not None:
node_run_result_list.append(node_run_result)
if node_run_result_list:
runner = _create_runner(EvaluationCategory(node_info.type), evaluation_instance, session)
results.extend(
runner.run(
evaluation_run_id=run_data.evaluation_run_id,
tenant_id=run_data.tenant_id,
target_id=run_data.target_id,
target_type=run_data.target_type,
default_metric=default_metric,
customized_metrics=None,
model_provider=run_data.evaluation_model_provider,
model_name=run_data.evaluation_model,
node_run_result_list=node_run_result_list,
judgment_config=run_data.judgment_config,
input_list=run_data.input_list,
)
)
if customized_metrics:
runner = _create_runner(EvaluationCategory.WORKFLOW, evaluation_instance, session)
results.extend(
runner.run(
evaluation_run_id=run_data.evaluation_run_id,
tenant_id=run_data.tenant_id,
target_id=run_data.target_id,
target_type=run_data.target_type,
default_metric=None,
customized_metrics=customized_metrics,
node_run_result_list=None,
node_run_result_mapping_list=node_run_result_mapping_list,
judgment_config=run_data.judgment_config,
input_list=run_data.input_list,
)
)
return results
def _create_runner(
category: EvaluationCategory,
evaluation_instance: BaseEvaluationInstance,
session: Any,
) -> BaseEvaluationRunner:
"""Create the appropriate runner for the evaluation category."""
match category:
case EvaluationCategory.LLM:
return LLMEvaluationRunner(evaluation_instance, session)
case EvaluationCategory.RETRIEVAL | EvaluationCategory.KNOWLEDGE_BASE:
return RetrievalEvaluationRunner(evaluation_instance, session)
case EvaluationCategory.AGENT:
return AgentEvaluationRunner(evaluation_instance, session)
case EvaluationCategory.WORKFLOW:
return WorkflowEvaluationRunner(evaluation_instance, session)
case EvaluationCategory.SNIPPET:
return SnippetEvaluationRunner(evaluation_instance, session)
case _:
raise ValueError(f"Unknown evaluation category: {category}")
def _execute_retrieval_test(
session: Any,
evaluation_run: EvaluationRun,
run_data: EvaluationRunData,
evaluation_instance: BaseEvaluationInstance,
) -> list[EvaluationItemResult]:
"""Execute knowledge base retrieval for all items, then evaluate metrics.
Unlike the workflow-based path, there are no workflow nodes to traverse.
Hit testing is run directly for each dataset item and the results are fed
straight into :class:`RetrievalEvaluationRunner`.
"""
node_run_result_list = EvaluationService.execute_retrieval_test_targets(
dataset_id=run_data.target_id,
account_id=evaluation_run.created_by,
input_list=run_data.input_list,
)
results: list[EvaluationItemResult] = []
runner = RetrievalEvaluationRunner(evaluation_instance, session)
results.extend(
runner.run(
evaluation_run_id=run_data.evaluation_run_id,
tenant_id=run_data.tenant_id,
target_id=run_data.target_id,
target_type=run_data.target_type,
default_metric=None,
model_provider=run_data.evaluation_model_provider,
model_name=run_data.evaluation_model,
node_run_result_list=node_run_result_list,
judgment_config=run_data.judgment_config,
input_list=run_data.input_list,
)
)
return results
def _backfill_workflow_run_ids(
session: Any,
evaluation_run_id: str,
input_list: list[EvaluationDatasetInput],
workflow_run_ids: list[str | None],
) -> None:
"""Set ``workflow_run_id`` on items that were created by the runner."""
from models.evaluation import EvaluationRunItem
for item, wf_run_id in zip(input_list, workflow_run_ids):
if not wf_run_id:
continue
run_item = (
session.query(EvaluationRunItem)
.filter_by(evaluation_run_id=evaluation_run_id, item_index=item.index)
.first()
)
if run_item:
run_item.workflow_run_id = wf_run_id
session.commit()
def _mark_run_failed(session: Any, run_id: str, error: str) -> None:
"""Mark an evaluation run as failed."""
try:
evaluation_run = session.query(EvaluationRun).filter_by(id=run_id).first()
if evaluation_run:
evaluation_run.status = EvaluationRunStatus.FAILED
evaluation_run.error = error[:2000] # Truncate error
evaluation_run.completed_at = naive_utc_now()
session.commit()
except Exception:
logger.exception("Failed to mark run %s as failed", run_id)
def _compute_metrics_summary(
results: list[EvaluationItemResult],
judgment_config: JudgmentConfig | None,
) -> dict[str, Any]:
"""Compute aggregate metric and judgment summaries for an evaluation run.
Metric statistics are calculated from successful item results only. When a
judgment config is present, the summary also reports how many successful
items passed or failed the configured judgment rules.
"""
summary: dict[str, Any] = {}
if judgment_config is not None and judgment_config.conditions:
evaluated_results: list[EvaluationItemResult] = [
result for result in results if result.error is None and result.metrics
]
passed_items = sum(1 for result in evaluated_results if result.judgment.passed)
evaluated_items = len(evaluated_results)
summary["_judgment"] = {
"enabled": True,
"logical_operator": judgment_config.logical_operator,
"configured_conditions": len(judgment_config.conditions),
"evaluated_items": evaluated_items,
"passed_items": passed_items,
"failed_items": evaluated_items - passed_items,
"pass_rate": passed_items / evaluated_items if evaluated_items else 0.0,
}
return summary
def _generate_result_xlsx(
input_list: list[EvaluationDatasetInput],
results: list[EvaluationItemResult],
) -> bytes:
"""Generate result XLSX with input data, actual output, metric scores, and judgment."""
wb = Workbook()
ws = wb.active
if ws is None:
ws = wb.create_sheet("Evaluation Results")
ws.title = "Evaluation Results"
header_font = Font(bold=True, color="FFFFFF")
header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
header_alignment = Alignment(horizontal="center", vertical="center")
thin_border = Border(
left=Side(style="thin"),
right=Side(style="thin"),
top=Side(style="thin"),
bottom=Side(style="thin"),
)
# Collect all metric names
all_metric_names: list[str] = []
for result in results:
for metric in result.metrics:
if metric.name not in all_metric_names:
all_metric_names.append(metric.name)
# Collect all input keys
input_keys: list[str] = []
for item in input_list:
for key in item.inputs:
if key not in input_keys:
input_keys.append(key)
# Include judgment column only when at least one result has judgment conditions evaluated
has_judgment = any(bool(r.judgment.condition_results) for r in results)
# Build headers
judgment_headers = ["judgment"] if has_judgment else []
headers = (
["index"] + input_keys + ["expected_output", "actual_output"] + all_metric_names + judgment_headers + ["error"]
)
# Write header row
for col_idx, header in enumerate(headers, start=1):
cell = ws.cell(row=1, column=col_idx, value=header)
cell.font = header_font
cell.fill = header_fill
cell.alignment = header_alignment
cell.border = thin_border
# Set column widths
ws.column_dimensions["A"].width = 10
for col_idx in range(2, len(headers) + 1):
ws.column_dimensions[get_column_letter(col_idx)].width = 25
# Build result lookup
result_by_index = {r.index: r for r in results}
# Write data rows
for row_idx, item in enumerate(input_list, start=2):
result = result_by_index.get(item.index)
col = 1
# Index
ws.cell(row=row_idx, column=col, value=item.index).border = thin_border
col += 1
# Input values
for key in input_keys:
val = item.inputs.get(key, "")
ws.cell(row=row_idx, column=col, value=str(val)).border = thin_border
col += 1
# Expected output
ws.cell(row=row_idx, column=col, value=item.expected_output or "").border = thin_border
col += 1
# Actual output
ws.cell(row=row_idx, column=col, value=result.actual_output if result else "").border = thin_border
col += 1
# Metric scores
metric_scores = {m.name: m.value for m in result.metrics} if result else {}
for metric_name in all_metric_names:
score = metric_scores.get(metric_name)
ws.cell(row=row_idx, column=col, value=score if score is not None else "").border = thin_border
col += 1
# Judgment result
if has_judgment:
if result and result.judgment.condition_results:
judgment_value = "Pass" if result.judgment.passed else "Fail"
else:
judgment_value = ""
ws.cell(row=row_idx, column=col, value=judgment_value).border = thin_border
col += 1
# Error
ws.cell(row=row_idx, column=col, value=result.error if result else "").border = thin_border
output = io.BytesIO()
wb.save(output)
output.seek(0)
return output.getvalue()
def _store_result_file(
tenant_id: str,
run_id: str,
xlsx_content: bytes,
session: Any,
) -> str | None:
"""Store result XLSX file and return the UploadFile ID."""
try:
from extensions.ext_storage import storage
from libs.uuid_utils import uuidv7
filename = f"evaluation-result-{run_id[:8]}.xlsx"
storage_key = f"evaluation_results/{tenant_id}/{str(uuidv7())}.xlsx"
storage.save(storage_key, xlsx_content)
upload_file: UploadFile = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=storage_key,
name=filename,
size=len(xlsx_content),
extension="xlsx",
mime_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
created_by_role=CreatorUserRole.ACCOUNT,
created_by="system",
created_at=naive_utc_now(),
used=False,
)
session.add(upload_file)
session.commit()
return upload_file.id
except Exception:
logger.exception("Failed to store result file for run %s", run_id)
return None
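
Note: the removed _compute_metrics_summary helper above reports judgment statistics only over items that evaluated successfully (error is None and at least one metric is present). A small self-contained example of the same arithmetic, with made-up numbers:

# Illustrative numbers only: 10 items, 2 errored, 6 of the remaining 8 passed.
evaluated_items = 8
passed_items = 6

judgment_summary = {
    "enabled": True,
    "evaluated_items": evaluated_items,
    "passed_items": passed_items,
    "failed_items": evaluated_items - passed_items,  # 2
    # Guard against division by zero, exactly as the task code does.
    "pass_rate": passed_items / evaluated_items if evaluated_items else 0.0,  # 0.75
}
print(judgment_summary)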
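Note: the removed _generate_result_xlsx assembles its header row from the dataset's input keys plus fixed columns, adding a judgment column only when at least one result evaluated judgment conditions. A minimal self-contained openpyxl sketch of that column layout (the sample input keys and metric names are made up):

import io

from openpyxl import Workbook
from openpyxl.utils import get_column_letter

# Made-up column sources; the real task collects these from the run data.
input_keys = ["query"]
metric_names = ["faithfulness", "relevance"]
has_judgment = True

headers = (
    ["index"]
    + input_keys
    + ["expected_output", "actual_output"]
    + metric_names
    + (["judgment"] if has_judgment else [])
    + ["error"]
)

wb = Workbook()
ws = wb.active
ws.title = "Evaluation Results"
for col_idx, header in enumerate(headers, start=1):
    ws.cell(row=1, column=col_idx, value=header)
    ws.column_dimensions[get_column_letter(col_idx)].width = 25

buffer = io.BytesIO()
wb.save(buffer)
print(len(buffer.getvalue()), "bytes written")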

View File

@@ -39,17 +39,36 @@ def process_trace_tasks(file_info):
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
try:
trace_type = trace_info_info_map.get(trace_info_type)
if trace_type:
trace_info = trace_type(**trace_info)
from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled
if is_ee_telemetry_enabled():
from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace
try:
EnterpriseOtelTrace().trace(trace_info)
except Exception:
logger.exception("Enterprise trace failed for app_id: %s", app_id)
if trace_instance:
with current_app.app_context():
trace_type = trace_info_info_map.get(trace_info_type)
if trace_type:
trace_info = trace_type(**trace_info)
trace_instance.trace(trace_info)
logger.info("Processing trace tasks success, app_id: %s", app_id)
except Exception as e:
logger.info("error:\n\n\n%s\n\n\n\n", e)
logger.exception("Processing trace tasks failed, app_id: %s", app_id)
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
logger.info("Processing trace tasks failed, app_id: %s", app_id)
finally:
storage.delete(file_path)
try:
storage.delete(file_path)
except Exception as e:
logger.warning(
"Failed to delete trace file %s for app_id %s: %s",
file_path,
app_id,
e,
)
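
Note: the hunk above adds two defensive patterns to process_trace_tasks: the enterprise exporter only runs when the feature is enabled and must never break the regular tracing path, and the temp-file cleanup becomes non-fatal. A compact sketch of both, with the enablement check, exporter, and storage delete passed in as stand-ins for the real extensions:

import logging
from collections.abc import Callable
from typing import Any

logger = logging.getLogger(__name__)


def export_if_enabled(
    is_enabled: Callable[[], bool],
    export: Callable[[Any], None],
    trace_info: Any,
    app_id: str,
) -> None:
    # Enterprise export is additive: failures are logged, never re-raised.
    if is_enabled():
        try:
            export(trace_info)
        except Exception:
            logger.exception("Enterprise trace failed for app_id: %s", app_id)


def cleanup_trace_file(delete: Callable[[str], None], file_path: str, app_id: str) -> None:
    # Cleanup is best-effort: a missing or already-deleted file should not
    # turn an otherwise successful trace task into a failure.
    try:
        delete(file_path)
    except Exception as e:
        logger.warning("Failed to delete trace file %s for app_id %s: %s", file_path, app_id, e)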

View File

@@ -125,7 +125,7 @@ def _create_node_execution_from_domain(
else:
node_execution.execution_metadata = "{}"
node_execution.status = execution.status.value
node_execution.status = execution.status
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.created_by_role = creator_user_role
@@ -159,7 +159,7 @@ def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionMode
node_execution.execution_metadata = "{}"
# Update other fields
node_execution.status = execution.status.value
node_execution.status = execution.status
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.finished_at = execution.finished_at
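
Note: the two one-line changes above drop the explicit .value and assign the status enum member directly. Assuming the status enum is a str-based enum (Python 3.11+ StrEnum), the member already behaves as its string value when stored or compared, so .value is redundant. A self-contained illustration with an illustrative subset of values:

import enum


class WorkflowNodeExecutionStatus(enum.StrEnum):
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"


status = WorkflowNodeExecutionStatus.SUCCEEDED

# A StrEnum member *is* a str, so it can be assigned to a string column
# and compared against plain strings without reaching for .value.
assert status == "succeeded"
assert str(status) == "succeeded"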

View File

@@ -1,11 +1,12 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta
from datetime import timedelta
from decimal import Decimal
from uuid import uuid4
from graphon.nodes.human_input.entities import FormDefinition, UserAction
from libs.datetime_utils import naive_utc_now
from models.account import Account, Tenant, TenantAccountJoin
from models.enums import ConversationFromSource, InvokeFrom
from models.execution_extra_content import HumanInputContent
@@ -117,7 +118,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
inputs=[],
user_actions=[UserAction(id=action_id, title=action_text)],
rendered_content="Rendered block",
expiration_time=datetime.utcnow() + timedelta(days=1),
expiration_time=naive_utc_now() + timedelta(days=1),
node_title=node_title,
display_in_ui=True,
)
@@ -129,7 +130,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
form_definition=form_definition.model_dump_json(),
rendered_content="Rendered block",
status=HumanInputFormStatus.SUBMITTED,
expiration_time=datetime.utcnow() + timedelta(days=1),
expiration_time=naive_utc_now() + timedelta(days=1),
selected_action_id=action_id,
)
db_session.add(form)

View File

@@ -7,7 +7,7 @@ from __future__ import annotations
from collections.abc import Generator
from dataclasses import dataclass
from datetime import datetime, timedelta
from datetime import timedelta
from decimal import Decimal
from uuid import uuid4
@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session, sessionmaker
from graphon.nodes.human_input.entities import FormDefinition, UserAction
from graphon.nodes.human_input.enums import HumanInputFormStatus
from libs.datetime_utils import naive_utc_now
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import ConversationFromSource, InvokeFrom
from models.execution_extra_content import ExecutionExtraContent, HumanInputContent
@@ -174,7 +175,7 @@ def _create_submitted_form(
action_title: str = "Approve",
node_title: str = "Approval",
) -> HumanInputForm:
expiration_time = datetime.utcnow() + timedelta(days=1)
expiration_time = naive_utc_now() + timedelta(days=1)
form_definition = FormDefinition(
form_content="content",
inputs=[],
@@ -207,7 +208,7 @@ def _create_waiting_form(
workflow_run_id: str,
default_values: dict | None = None,
) -> HumanInputForm:
expiration_time = datetime.utcnow() + timedelta(days=1)
expiration_time = naive_utc_now() + timedelta(days=1)
form_definition = FormDefinition(
form_content="content",
inputs=[],

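Note: both fixture modules above swap the deprecated datetime.utcnow() for the project's naive_utc_now() helper. A sketch of what such a helper typically does; the body shown is an assumption for illustration, not a copy of libs.datetime_utils:

from datetime import UTC, datetime, timedelta


def naive_utc_now() -> datetime:
    # Current time in UTC with tzinfo stripped, matching columns that store
    # naive UTC timestamps (datetime.utcnow() is deprecated since Python 3.12).
    return datetime.now(UTC).replace(tzinfo=None)


expiration_time = naive_utc_now() + timedelta(days=1)
print(expiration_time.tzinfo)  # None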
View File

@@ -0,0 +1,289 @@
from __future__ import annotations
import json
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_service import ApiKeyAuthService
class TestApiKeyAuthService:
@pytest.fixture
def tenant_id(self) -> str:
return str(uuid4())
@pytest.fixture
def category(self) -> str:
return "search"
@pytest.fixture
def provider(self) -> str:
return "google"
@pytest.fixture
def mock_credentials(self) -> dict:
return {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}}
@pytest.fixture
def mock_args(self, category, provider, mock_credentials) -> dict:
return {"category": category, "provider": provider, "credentials": mock_credentials}
def _create_binding(self, db_session, *, tenant_id, category, provider, credentials=None, disabled=False):
binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant_id,
category=category,
provider=provider,
credentials=json.dumps(credentials, ensure_ascii=False) if credentials else None,
disabled=disabled,
)
db_session.add(binding)
db_session.commit()
return binding
def test_get_provider_auth_list_success(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
):
self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider)
db_session_with_containers.expire_all()
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
assert len(result) >= 1
tenant_results = [r for r in result if r.tenant_id == tenant_id]
assert len(tenant_results) == 1
assert tenant_results[0].provider == provider
def test_get_provider_auth_list_empty(self, flask_app_with_containers, db_session_with_containers, tenant_id):
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
tenant_results = [r for r in result if r.tenant_id == tenant_id]
assert tenant_results == []
def test_get_provider_auth_list_filters_disabled(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
):
self._create_binding(
db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider, disabled=True
)
db_session_with_containers.expire_all()
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
tenant_results = [r for r in result if r.tenant_id == tenant_id]
assert tenant_results == []
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_success(
self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args
):
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123"
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
mock_factory.assert_called_once()
mock_auth_instance.validate_credentials.assert_called_once()
mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, "test_secret_key_123")
db_session_with_containers.expire_all()
bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all()
assert len(bindings) == 1
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
def test_create_provider_auth_validation_failed(
self, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args
):
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = False
mock_factory.return_value = mock_auth_instance
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
db_session_with_containers.expire_all()
bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all()
assert len(bindings) == 0
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_encrypts_api_key(
self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args
):
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123"
original_key = mock_args["credentials"]["config"]["api_key"]
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
assert mock_args["credentials"]["config"]["api_key"] == "encrypted_test_key_123"
assert mock_args["credentials"]["config"]["api_key"] != original_key
mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, original_key)
def test_get_auth_credentials_success(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider, mock_credentials
):
self._create_binding(
db_session_with_containers,
tenant_id=tenant_id,
category=category,
provider=provider,
credentials=mock_credentials,
)
db_session_with_containers.expire_all()
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
assert result == mock_credentials
def test_get_auth_credentials_not_found(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
):
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
assert result is None
def test_get_auth_credentials_json_parsing(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
):
special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}}
self._create_binding(
db_session_with_containers,
tenant_id=tenant_id,
category=category,
provider=provider,
credentials=special_credentials,
)
db_session_with_containers.expire_all()
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
assert result == special_credentials
assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%"
def test_delete_provider_auth_success(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
):
binding = self._create_binding(
db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider
)
binding_id = binding.id
db_session_with_containers.expire_all()
ApiKeyAuthService.delete_provider_auth(tenant_id, binding_id)
db_session_with_containers.expire_all()
remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first()
assert remaining is None
def test_delete_provider_auth_not_found(self, flask_app_with_containers, db_session_with_containers, tenant_id):
# Should not raise when binding not found
ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4()))
def test_validate_api_key_auth_args_success(self, mock_args):
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
def test_validate_api_key_auth_args_missing_category(self, mock_args):
del mock_args["category"]
with pytest.raises(ValueError, match="category is required"):
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
def test_validate_api_key_auth_args_empty_category(self, mock_args):
mock_args["category"] = ""
with pytest.raises(ValueError, match="category is required"):
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
def test_validate_api_key_auth_args_missing_provider(self, mock_args):
del mock_args["provider"]
with pytest.raises(ValueError, match="provider is required"):
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
def test_validate_api_key_auth_args_empty_provider(self, mock_args):
mock_args["provider"] = ""
with pytest.raises(ValueError, match="provider is required"):
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
def test_validate_api_key_auth_args_missing_credentials(self, mock_args):
del mock_args["credentials"]
with pytest.raises(ValueError, match="credentials is required"):
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
def test_validate_api_key_auth_args_empty_credentials(self, mock_args):
mock_args["credentials"] = None
with pytest.raises(ValueError, match="credentials is required"):
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
def test_validate_api_key_auth_args_invalid_credentials_type(self, mock_args):
mock_args["credentials"] = "not_a_dict"
with pytest.raises(ValueError, match="credentials must be a dictionary"):
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
def test_validate_api_key_auth_args_missing_auth_type(self, mock_args):
del mock_args["credentials"]["auth_type"]
with pytest.raises(ValueError, match="auth_type is required"):
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
def test_validate_api_key_auth_args_empty_auth_type(self, mock_args):
mock_args["credentials"]["auth_type"] = ""
with pytest.raises(ValueError, match="auth_type is required"):
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
@pytest.mark.parametrize(
"malicious_input",
[
"<script>alert('xss')</script>",
"'; DROP TABLE users; --",
"../../../etc/passwd",
"\\x00\\x00",
"A" * 10000,
],
)
def test_validate_api_key_auth_args_malicious_input(self, malicious_input, mock_args):
mock_args["category"] = malicious_input
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_database_error_handling(
self, mock_encrypter, mock_factory, flask_app_with_containers, tenant_id, mock_args
):
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
mock_encrypter.encrypt_token.return_value = "encrypted_key"
with patch("services.auth.api_key_auth_service.db.session") as mock_session:
mock_session.commit.side_effect = Exception("Database error")
with pytest.raises(Exception, match="Database error"):
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
def test_create_provider_auth_factory_exception(self, mock_factory, tenant_id, mock_args):
mock_factory.side_effect = Exception("Factory error")
with pytest.raises(Exception, match="Factory error"):
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, tenant_id, mock_args):
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
mock_encrypter.encrypt_token.side_effect = Exception("Encryption error")
with pytest.raises(Exception, match="Encryption error"):
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
def test_validate_api_key_auth_args_none_input(self):
with pytest.raises(TypeError):
ApiKeyAuthService.validate_api_key_auth_args(None)
def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self, mock_args):
mock_args["credentials"]["auth_type"] = ["api_key"]
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
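
Note: taken together, the validation tests above pin down a contract for validate_api_key_auth_args: category, provider, and credentials must be present and non-empty, credentials must be a dict, and credentials["auth_type"] must be set, while arbitrary string content is accepted as-is. A sketch of a validator that would satisfy those assertions; this is inferred from the tests, not copied from the service:

from typing import Any


def validate_api_key_auth_args(args: dict[str, Any]) -> None:
    # Error messages mirror the pytest.raises match strings above.  Indexing
    # into a None argument raises TypeError, which the none-input test expects.
    if "category" not in args or not args["category"]:
        raise ValueError("category is required")
    if "provider" not in args or not args["provider"]:
        raise ValueError("provider is required")
    if "credentials" not in args or not args["credentials"]:
        raise ValueError("credentials is required")
    if not isinstance(args["credentials"], dict):
        raise ValueError("credentials must be a dictionary")
    if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
        raise ValueError("auth_type is required")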

View File

@@ -0,0 +1,264 @@
"""
API Key Authentication System Integration Tests
"""
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, patch
from uuid import uuid4
import httpx
import pytest
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.api_key_auth_service import ApiKeyAuthService
from services.auth.auth_type import AuthType
class TestAuthIntegration:
@pytest.fixture
def tenant_id_1(self) -> str:
return str(uuid4())
@pytest.fixture
def tenant_id_2(self) -> str:
return str(uuid4())
@pytest.fixture
def category(self) -> str:
return "search"
@pytest.fixture
def firecrawl_credentials(self) -> dict:
return {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}}
@pytest.fixture
def jina_credentials(self) -> dict:
return {"auth_type": "bearer", "config": {"api_key": "jina_test_key_456"}}
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_end_to_end_auth_flow(
self,
mock_encrypt,
mock_http,
flask_app_with_containers,
db_session_with_containers,
tenant_id_1,
category,
firecrawl_credentials,
):
mock_http.return_value = self._create_success_response()
mock_encrypt.return_value = "encrypted_fc_test_key_123"
args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials}
ApiKeyAuthService.create_provider_auth(tenant_id_1, args)
mock_http.assert_called_once()
call_args = mock_http.call_args
assert "https://api.firecrawl.dev/v1/crawl" in call_args[0][0]
assert call_args[1]["headers"]["Authorization"] == "Bearer fc_test_key_123"
mock_encrypt.assert_called_once_with(tenant_id_1, "fc_test_key_123")
db_session_with_containers.expire_all()
bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all()
assert len(bindings) == 1
assert bindings[0].provider == AuthType.FIRECRAWL
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_cross_component_integration(self, mock_http, firecrawl_credentials):
mock_http.return_value = self._create_success_response()
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, firecrawl_credentials)
result = factory.validate_credentials()
assert result is True
mock_http.assert_called_once()
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.jina.jina.httpx.post")
def test_multi_tenant_isolation(
self,
mock_jina_http,
mock_fc_http,
mock_encrypt,
flask_app_with_containers,
db_session_with_containers,
tenant_id_1,
tenant_id_2,
category,
firecrawl_credentials,
jina_credentials,
):
mock_fc_http.return_value = self._create_success_response()
mock_jina_http.return_value = self._create_success_response()
mock_encrypt.return_value = "encrypted_key"
args1 = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials}
ApiKeyAuthService.create_provider_auth(tenant_id_1, args1)
args2 = {"category": category, "provider": AuthType.JINA, "credentials": jina_credentials}
ApiKeyAuthService.create_provider_auth(tenant_id_2, args2)
db_session_with_containers.expire_all()
result1 = ApiKeyAuthService.get_provider_auth_list(tenant_id_1)
result2 = ApiKeyAuthService.get_provider_auth_list(tenant_id_2)
assert len(result1) == 1
assert result1[0].tenant_id == tenant_id_1
assert len(result2) == 1
assert result2[0].tenant_id == tenant_id_2
def test_cross_tenant_access_prevention(
self, flask_app_with_containers, db_session_with_containers, tenant_id_2, category
):
result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL)
assert result is None
def test_sensitive_data_protection(self):
credentials_with_secrets = {
"auth_type": "bearer",
"config": {"api_key": "super_secret_key_do_not_log", "secret": "another_secret"},
}
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, credentials_with_secrets)
factory_str = str(factory)
assert "super_secret_key_do_not_log" not in factory_str
assert "another_secret" not in factory_str
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token", return_value="encrypted_key")
def test_concurrent_creation_safety(
self,
mock_encrypt,
mock_http,
flask_app_with_containers,
db_session_with_containers,
tenant_id_1,
category,
firecrawl_credentials,
):
app = flask_app_with_containers
mock_http.return_value = self._create_success_response()
results = []
exceptions = []
def create_auth():
try:
with app.app_context():
thread_args = {
"category": category,
"provider": AuthType.FIRECRAWL,
"credentials": {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}},
}
ApiKeyAuthService.create_provider_auth(tenant_id_1, thread_args)
results.append("success")
except Exception as e:
exceptions.append(e)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(create_auth) for _ in range(5)]
for future in futures:
future.result()
assert len(results) == 5
assert len(exceptions) == 0
@pytest.mark.parametrize(
"invalid_input",
[
None,
{},
{"auth_type": "bearer"},
{"auth_type": "bearer", "config": {}},
],
)
def test_invalid_input_boundary(self, invalid_input):
with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_http_error_handling(self, mock_http, firecrawl_credentials):
mock_response = Mock()
mock_response.status_code = 401
mock_response.text = '{"error": "Unauthorized"}'
mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized")
mock_http.return_value = mock_response
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, firecrawl_credentials)
with pytest.raises((httpx.HTTPError, Exception)):
factory.validate_credentials()
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_network_failure_recovery(
self,
mock_http,
flask_app_with_containers,
db_session_with_containers,
tenant_id_1,
category,
firecrawl_credentials,
):
mock_http.side_effect = httpx.RequestError("Network timeout")
args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials}
with pytest.raises(httpx.RequestError):
ApiKeyAuthService.create_provider_auth(tenant_id_1, args)
db_session_with_containers.expire_all()
bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all()
assert len(bindings) == 0
@pytest.mark.parametrize(
("provider", "credentials"),
[
(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "fc_key"}}),
(AuthType.JINA, {"auth_type": "bearer", "config": {"api_key": "jina_key"}}),
(AuthType.WATERCRAWL, {"auth_type": "x-api-key", "config": {"api_key": "wc_key"}}),
],
)
def test_all_providers_factory_creation(self, provider, credentials):
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
assert auth_class is not None
factory = ApiKeyAuthFactory(provider, credentials)
assert factory.auth is not None
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_get_auth_credentials_returns_stored_credentials(
self,
mock_http,
mock_encrypt,
flask_app_with_containers,
db_session_with_containers,
tenant_id_1,
category,
firecrawl_credentials,
):
mock_http.return_value = self._create_success_response()
mock_encrypt.return_value = "encrypted_key"
args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials}
ApiKeyAuthService.create_provider_auth(tenant_id_1, args)
db_session_with_containers.expire_all()
result = ApiKeyAuthService.get_auth_credentials(tenant_id_1, category, AuthType.FIRECRAWL)
assert result is not None
assert result["config"]["api_key"] == "encrypted_key"
def _create_success_response(self, status_code=200):
mock_response = Mock()
mock_response.status_code = status_code
mock_response.json.return_value = {"status": "success"}
mock_response.raise_for_status.return_value = None
return mock_response
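
Note: the end-to-end test above pins the create_provider_auth flow: credentials are first validated against the provider's API over HTTP, the plaintext key is then encrypted per tenant, and finally a DataSourceApiKeyAuthBinding row is persisted with the encrypted value. A condensed sketch of that flow under the same assumptions the mocks encode; the helper names and dataclass are illustrative, not the service's real internals:

import json
from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class Binding:
    # Stand-in for DataSourceApiKeyAuthBinding.
    tenant_id: str
    category: str
    provider: str
    credentials: str


def create_provider_auth(
    tenant_id: str,
    args: dict,
    validate: Callable[[dict], bool],          # e.g. an HTTP check against the provider
    encrypt_token: Callable[[str, str], str],  # per-tenant encryption of the API key
    save: Callable[[Binding], None],           # persistence
) -> None:
    if not validate(args["credentials"]):
        # Nothing is stored when validation fails (as the unit tests above assert).
        return
    api_key = args["credentials"]["config"]["api_key"]
    args["credentials"]["config"]["api_key"] = encrypt_token(tenant_id, api_key)
    save(
        Binding(
            tenant_id=tenant_id,
            category=args["category"],
            provider=args["provider"],
            credentials=json.dumps(args["credentials"], ensure_ascii=False),
        )
    )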

View File

@@ -8,15 +8,27 @@ verification, marketplace upgrade flows, and uninstall with credential cleanup.
from __future__ import annotations
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy import select
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.entities.plugin_daemon import PluginVerification
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import PluginInstallationScope
from services.plugin.plugin_service import PluginService
from tests.unit_tests.services.plugin.conftest import make_features
def _make_features(
restrict_to_marketplace: bool = False,
scope: PluginInstallationScope = PluginInstallationScope.ALL,
) -> MagicMock:
features = MagicMock()
features.plugin_installation_permission.restrict_to_marketplace_only = restrict_to_marketplace
features.plugin_installation_permission.plugin_installation_scope = scope
return features
class TestFetchLatestPluginVersion:
@@ -80,14 +92,14 @@ class TestFetchLatestPluginVersion:
class TestCheckMarketplaceOnlyPermission:
@patch("services.plugin.plugin_service.FeatureService")
def test_raises_when_restricted(self, mock_fs):
mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=True)
mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=True)
with pytest.raises(PluginInstallationForbiddenError):
PluginService._check_marketplace_only_permission()
@patch("services.plugin.plugin_service.FeatureService")
def test_passes_when_not_restricted(self, mock_fs):
mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=False)
mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=False)
PluginService._check_marketplace_only_permission() # should not raise
@@ -95,7 +107,7 @@ class TestCheckMarketplaceOnlyPermission:
class TestCheckPluginInstallationScope:
@patch("services.plugin.plugin_service.FeatureService")
def test_official_only_allows_langgenius(self, mock_fs):
mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY)
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY)
verification = MagicMock()
verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius
@@ -103,14 +115,14 @@ class TestCheckPluginInstallationScope:
@patch("services.plugin.plugin_service.FeatureService")
def test_official_only_rejects_third_party(self, mock_fs):
mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY)
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY)
with pytest.raises(PluginInstallationForbiddenError):
PluginService._check_plugin_installation_scope(None)
@patch("services.plugin.plugin_service.FeatureService")
def test_official_and_partners_allows_partner(self, mock_fs):
mock_fs.get_system_features.return_value = make_features(
mock_fs.get_system_features.return_value = _make_features(
scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS
)
verification = MagicMock()
@@ -120,7 +132,7 @@ class TestCheckPluginInstallationScope:
@patch("services.plugin.plugin_service.FeatureService")
def test_official_and_partners_rejects_none(self, mock_fs):
mock_fs.get_system_features.return_value = make_features(
mock_fs.get_system_features.return_value = _make_features(
scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS
)
@@ -129,7 +141,7 @@ class TestCheckPluginInstallationScope:
@patch("services.plugin.plugin_service.FeatureService")
def test_none_scope_always_raises(self, mock_fs):
mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.NONE)
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.NONE)
verification = MagicMock()
verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius
@@ -138,7 +150,7 @@ class TestCheckPluginInstallationScope:
@patch("services.plugin.plugin_service.FeatureService")
def test_all_scope_passes_any(self, mock_fs):
mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.ALL)
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.ALL)
PluginService._check_plugin_installation_scope(None) # should not raise
@@ -209,9 +221,9 @@ class TestUpgradePluginWithMarketplace:
@patch("services.plugin.plugin_service.dify_config")
def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace):
mock_config.MARKETPLACE_ENABLED = True
mock_fs.get_system_features.return_value = make_features()
mock_fs.get_system_features.return_value = _make_features()
installer = mock_installer_cls.return_value
installer.fetch_plugin_manifest.return_value = MagicMock() # no exception = already installed
installer.fetch_plugin_manifest.return_value = MagicMock()
installer.upgrade_plugin.return_value = MagicMock()
PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid")
@@ -225,7 +237,7 @@ class TestUpgradePluginWithMarketplace:
@patch("services.plugin.plugin_service.dify_config")
def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download):
mock_config.MARKETPLACE_ENABLED = True
mock_fs.get_system_features.return_value = make_features()
mock_fs.get_system_features.return_value = _make_features()
installer = mock_installer_cls.return_value
installer.fetch_plugin_manifest.side_effect = RuntimeError("not found")
mock_download.return_value = b"pkg-bytes"
@@ -244,7 +256,7 @@ class TestUpgradePluginWithGithub:
@patch("services.plugin.plugin_service.FeatureService")
@patch("services.plugin.plugin_service.PluginInstaller")
def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs):
mock_fs.get_system_features.return_value = make_features()
mock_fs.get_system_features.return_value = _make_features()
installer = mock_installer_cls.return_value
installer.upgrade_plugin.return_value = MagicMock()
@@ -259,7 +271,7 @@ class TestUploadPkg:
@patch("services.plugin.plugin_service.FeatureService")
@patch("services.plugin.plugin_service.PluginInstaller")
def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs):
mock_fs.get_system_features.return_value = make_features()
mock_fs.get_system_features.return_value = _make_features()
upload_resp = MagicMock()
upload_resp.verification = None
mock_installer_cls.return_value.upload_pkg.return_value = upload_resp
@@ -283,7 +295,7 @@ class TestInstallFromMarketplacePkg:
@patch("services.plugin.plugin_service.dify_config")
def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download):
mock_config.MARKETPLACE_ENABLED = True
mock_fs.get_system_features.return_value = make_features()
mock_fs.get_system_features.return_value = _make_features()
installer = mock_installer_cls.return_value
installer.fetch_plugin_manifest.side_effect = RuntimeError("not found")
mock_download.return_value = b"pkg"
@@ -298,14 +310,14 @@ class TestInstallFromMarketplacePkg:
assert result == "task-id"
installer.install_from_identifiers.assert_called_once()
call_args = installer.install_from_identifiers.call_args[0]
assert call_args[1] == ["resolved-uid"] # uses response uid, not input
assert call_args[1] == ["resolved-uid"]
@patch("services.plugin.plugin_service.FeatureService")
@patch("services.plugin.plugin_service.PluginInstaller")
@patch("services.plugin.plugin_service.dify_config")
def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs):
mock_config.MARKETPLACE_ENABLED = True
mock_fs.get_system_features.return_value = make_features()
mock_fs.get_system_features.return_value = _make_features()
installer = mock_installer_cls.return_value
installer.fetch_plugin_manifest.return_value = MagicMock()
decode_resp = MagicMock()
@@ -317,7 +329,7 @@ class TestInstallFromMarketplacePkg:
installer.install_from_identifiers.assert_called_once()
call_args = installer.install_from_identifiers.call_args[0]
assert call_args[1] == ["uid-1"] # uses original uid
assert call_args[1] == ["uid-1"]
class TestUninstall:
@@ -332,26 +344,70 @@ class TestUninstall:
assert result is True
installer.uninstall.assert_called_once_with("t1", "install-1")
@patch("services.plugin.plugin_service.db")
@patch("services.plugin.plugin_service.PluginInstaller")
def test_cleans_credentials_when_plugin_found(self, mock_installer_cls, mock_db):
def test_cleans_credentials_when_plugin_found(
self, mock_installer_cls, flask_app_with_containers, db_session_with_containers
):
tenant_id = str(uuid4())
plugin_id = "org/myplugin"
provider_name = f"{plugin_id}/model-provider"
credential = ProviderCredential(
tenant_id=tenant_id,
provider_name=provider_name,
credential_name="default",
encrypted_config="{}",
)
db_session_with_containers.add(credential)
db_session_with_containers.flush()
credential_id = credential.id
provider = Provider(
tenant_id=tenant_id,
provider_name=provider_name,
credential_id=credential_id,
)
db_session_with_containers.add(provider)
db_session_with_containers.flush()
provider_id = provider.id
pref = TenantPreferredModelProvider(
tenant_id=tenant_id,
provider_name=provider_name,
preferred_provider_type="custom",
)
db_session_with_containers.add(pref)
db_session_with_containers.commit()
plugin = MagicMock()
plugin.installation_id = "install-1"
plugin.plugin_id = "org/myplugin"
plugin.plugin_id = plugin_id
installer = mock_installer_cls.return_value
installer.list_plugins.return_value = [plugin]
installer.uninstall.return_value = True
# Mock Session context manager
mock_session = MagicMock()
mock_db.engine = MagicMock()
mock_session.scalars.return_value.all.return_value = [] # no credentials found
with patch("services.plugin.plugin_service.Session") as mock_session_cls:
mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_session_cls.return_value.__exit__ = MagicMock(return_value=False)
result = PluginService.uninstall("t1", "install-1")
with patch("services.plugin.plugin_service.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = False
result = PluginService.uninstall(tenant_id, "install-1")
assert result is True
installer.uninstall.assert_called_once()
db_session_with_containers.expire_all()
remaining_creds = db_session_with_containers.scalars(
select(ProviderCredential).where(ProviderCredential.id == credential_id)
).all()
assert len(remaining_creds) == 0
updated_provider = db_session_with_containers.get(Provider, provider_id)
assert updated_provider is not None
assert updated_provider.credential_id is None
remaining_prefs = db_session_with_containers.scalars(
select(TenantPreferredModelProvider).where(
TenantPreferredModelProvider.tenant_id == tenant_id,
TenantPreferredModelProvider.provider_name == provider_name,
)
).all()
assert len(remaining_prefs) == 0
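
Note: the rewritten uninstall test above verifies three cleanup effects against a real database: the plugin's ProviderCredential rows are deleted, Provider.credential_id is nulled out, and the TenantPreferredModelProvider preference is removed. A self-contained SQLAlchemy sketch of a cleanup step with that shape; the table definitions and the name-prefix matching are simplified assumptions, not the service's actual implementation:

from sqlalchemy import ForeignKey, String, create_engine, delete, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


# Minimal stand-ins for the real models; only the columns the cleanup touches.
class ProviderCredential(Base):
    __tablename__ = "provider_credentials"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String)
    provider_name: Mapped[str] = mapped_column(String)


class Provider(Base):
    __tablename__ = "providers"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String)
    provider_name: Mapped[str] = mapped_column(String)
    credential_id: Mapped[int | None] = mapped_column(ForeignKey("provider_credentials.id"), nullable=True)


class TenantPreferredModelProvider(Base):
    __tablename__ = "tenant_preferred_model_providers"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String)
    provider_name: Mapped[str] = mapped_column(String)


def clean_plugin_credentials(session: Session, tenant_id: str, plugin_id: str) -> None:
    prefix = f"{plugin_id}/%"
    credential_ids = session.scalars(
        select(ProviderCredential.id).where(
            ProviderCredential.tenant_id == tenant_id,
            ProviderCredential.provider_name.like(prefix),
        )
    ).all()
    # Detach providers first, then drop the credential rows and the preference.
    session.execute(
        update(Provider)
        .where(Provider.tenant_id == tenant_id, Provider.credential_id.in_(credential_ids))
        .values(credential_id=None)
    )
    session.execute(delete(ProviderCredential).where(ProviderCredential.id.in_(credential_ids)))
    session.execute(
        delete(TenantPreferredModelProvider).where(
            TenantPreferredModelProvider.tenant_id == tenant_id,
            TenantPreferredModelProvider.provider_name.like(prefix),
        )
    )
    session.commit()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all([
        ProviderCredential(id=1, tenant_id="t1", provider_name="org/myplugin/model-provider"),
        Provider(id=1, tenant_id="t1", provider_name="org/myplugin/model-provider", credential_id=1),
        TenantPreferredModelProvider(id=1, tenant_id="t1", provider_name="org/myplugin/model-provider"),
    ])
    session.commit()
    clean_plugin_credentials(session, "t1", "org/myplugin")
    print(session.get(Provider, 1).credential_id)  # None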

Some files were not shown because too many files have changed in this diff.