Compare commits

...

78 Commits

Author SHA1 Message Date
Asuka Minato
25c69ac540 one example of Session (#24135)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-09-24 03:32:48 +08:00
QuantumGhost
96a0b9991e fix(api): Fix variable truncation for list[File] value in output mapping (#26133)
2025-09-23 21:30:46 +08:00
QuantumGhost
2913d17fe2 ci: Add hotfix/** branches to build-push workflow triggers (#26129) 2025-09-23 18:48:02 +08:00
Wu Tianwei
d9e45a1abe feat(pipeline): add language support to built-in pipeline templates and update related components (#26124) 2025-09-23 18:18:22 +08:00
longbingljw
24b4289d6c fix: add some explanation for OceanBase parser selection (#26071)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-23 17:06:06 +08:00
GuanMu
fb6ccccc3d chore: refactor component exports for consistency (#26033)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-09-23 17:04:56 +08:00
17hz
8b74ae683a bump nextjs to 15.5 and turbopack for development mode (#24346)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
2025-09-23 16:59:26 +08:00
Jyong
dd08957381 fix full_text_search name (#26104) 2025-09-23 16:40:26 +08:00
quicksand
407323f817 fix(api): graph engine debug logging NodeRunRetryEvent not effective (#26085)
2025-09-23 13:46:45 +08:00
-LAN-
2e2c87c5a1 fix(graph_engine): error strategy fall. (#26078)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-09-23 01:51:43 +08:00
Asuka Minato
f4522fd695 try contextmanager (#26074)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-23 00:35:54 +08:00
夏目猫猫
760a2c656c amend regexp exec (#25986) 2025-09-23 00:47:13 +09:00
Asuka Minato
8940decd1b more httpx (#25651)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-09-22 23:07:09 +08:00
Jyong
0c4193bd91 fix avatar-url to text (#26068)
2025-09-22 21:28:42 +08:00
Jyong
cd40cde790 fix tenant not exist (#26066) 2025-09-22 20:50:30 +08:00
Jyong
c60c754ac9 fix preview url (#26059) 2025-09-22 19:47:39 +08:00
非法操作
ef80d3b707 fix: Ensure compatibility with old provider name when updating model credentials (#26017) 2025-09-22 19:39:17 +08:00
QuantumGhost
24e8d21b3f chore(api): bump version (#25917) 2025-09-22 19:14:43 +08:00
Novice
d823da18db fix: iteration and loop node single step run (#26036) 2025-09-22 19:14:24 +08:00
QuantumGhost
1e3df09fc6 chore(api): adjust monkey patching in gunicorn.conf.py (#26056) 2025-09-22 18:23:01 +08:00
Stream
75a10c276c chore: remove mistakenly added trash file (#26041) 2025-09-22 16:07:02 +08:00
Hunter
50050527eb fix: Correctly map source_url to preview_url in file fields (#25957) 2025-09-22 14:31:49 +08:00
Wu Tianwei
a39b185627 fix: comment out unused segmentation rule properties in RuleDetail component (#26031)
2025-09-22 14:17:02 +08:00
dependabot[bot]
15270f09af chore(deps): bump boto3-stubs from 1.40.29 to 1.40.35 in /api (#26014)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-22 12:16:05 +08:00
QuantumGhost
f6a5ac0698 chore(api): upgrade Gevent to 25.9.1 (#26026) 2025-09-22 12:15:50 +08:00
zyssyz123
2b79da722b fix: workflow (#26030) 2025-09-22 12:08:15 +08:00
-LAN-
71d69e43cd Align dev workflow branch triggers (#26029) 2025-09-22 11:56:28 +08:00
Yongtao Huang
5bc6e8a433 Fix: correct regex for file-preview URL re-signing (#25620)
Fixes #25619

The regex patterns for file-preview and image-preview contained an unescaped `?`, 
which caused incorrect matches such as `file-previe` or `image-previw`. 
This led to malformed URLs being incorrectly re-signed.

Changes:
- Escape `?` in both file-preview and image-preview regex patterns.
- Ensure only valid URLs are re-signed.

Added unit tests to cover:
- Valid file-preview and image-preview URLs (correctly re-signed).
- Misspelled file/image preview URLs (no longer incorrectly matched).

Other:
- Fix a deprecated function `datetime.utcnow()`

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-09-22 10:58:29 +08:00
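
The escaping bug described above is easy to demonstrate. A minimal sketch (the patterns and URLs below are simplified for illustration, not Dify's actual ones):

import re

# Unescaped '?' makes the preceding 'w' optional, so misspelled paths match.
bad = re.compile(r"/files/[0-9a-f-]+/file-preview?")
# Escaping it matches the literal query-string delimiter instead.
good = re.compile(r"/files/[0-9a-f-]+/file-preview\?")

typo_url = "/files/abc-123/file-previe"         # malformed URL
real_url = "/files/abc-123/file-preview?sig=x"  # valid preview URL

print(bool(bad.search(typo_url)))   # True  -- would have been wrongly re-signed
print(bool(good.search(typo_url)))  # False -- correctly rejected
print(bool(good.search(real_url)))  # True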
dependabot[bot]
68076f2e22 chore(deps): bump abcjs from 6.5.1 to 6.5.2 in /web (#26018)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-22 10:36:40 +08:00
Wu Tianwei
8c38363038 fix: pass operation name to onUpdate callback in StatusItem component (#26019) 2025-09-22 10:19:12 +08:00
Shili Cao
345ac8333c Add Full-Text & Hybrid Search Support to Baidu Vector DB and Update SDK, Closes #25982 (#25983)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-09-22 10:17:35 +08:00
dependabot[bot]
2375047ef0 chore(deps-dev): bump eslint-plugin-storybook from 0.11.6 to 9.0.7 in /web (#26011)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-09-22 10:03:02 +08:00
Yongtao Huang
857a48012e Fix: use data.type instead of type when checking datasource node (#25965)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-22 10:01:21 +08:00
longbingljw
208fe3d7de feat: support selecting different ftparser for OceanBase. (#25970)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-22 09:56:33 +08:00
dependabot[bot]
92cddbcc02 chore(deps): bump negotiator from 0.6.4 to 1.0.0 in /web (#26012)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-22 09:55:00 +08:00
dependabot[bot]
599b53c9cb chore(deps): bump authlib from 1.3.1 to 1.6.4 in /api (#26015)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-22 09:54:52 +08:00
jiangbo721
062b173c66 fix: Statistics, like workflows, do not include debug data. (#25979)
2025-09-20 10:47:59 +08:00
Yongtao Huang
db690013fd Chore: remove dead code in datasource.utils (#25984) 2025-09-20 10:47:52 +08:00
lyzno1
e93bfe3d41 fix: resolve chat sidebar UI bugs for hover panel and dropdown menu (#25813)
2025-09-19 18:28:49 +08:00
GuanMu
ab910c736c feat(goto-anything): add RAG pipeline node search (#25948) 2025-09-19 18:28:13 +08:00
Yeuoly
4047a6bb12 fix: ensure original response are maintained by yielding text messages in ApiTool (#23456) (#25973) 2025-09-19 18:27:33 +08:00
github-actions[bot]
df2478dc26 chore: translate i18n files and update type definitions (#25964)
Co-authored-by: WTW0313 <30284043+WTW0313@users.noreply.github.com>
2025-09-19 18:27:09 +08:00
-LAN-
4cc3f6045b Run import-linter within make lint (#25933) 2025-09-19 18:26:43 +08:00
Joel
1550316b8d fix: undefined match the wrong output schema (#25971) 2025-09-19 17:03:09 +08:00
Wu Tianwei
87394d2512 fix: enhance model parameter handling with advanced mode support and localization updates (#25963) 2025-09-19 15:47:52 +08:00
Wu Tianwei
bad59c95bc fix: update details display to conditionally show creator information (#25952) 2025-09-19 15:45:45 +08:00
Xiyuan Chen
9f138ef246 Refactor WorkflowService to handle missing default credentials gracef… (#25960) 2025-09-19 00:45:35 -07:00
zxhlyh
6453fc4973 fix: refresh datasource list after install datasource (#25949)
2025-09-19 11:03:45 +08:00
GuanMu
f62f926537 style: update GotoAnything component styling (#25929)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-09-19 10:36:43 +08:00
Yongtao Huang
b3dafd913b Chore: correct inconsistent logging and typo (#25945) 2025-09-19 10:36:16 +08:00
-LAN-
b2d8a7eaf1 Fix: enforce editor-only access to chat message logs (#25936)
2025-09-18 21:59:51 +08:00
GuanMu
3e54414191 chore: update post_create_command.sh to use dynamic workspace root for aliases (#25913) 2025-09-18 21:09:43 +08:00
-LAN-
a173546c8d Fix: replace stdout prints with debug logging (#25931)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-09-18 21:03:20 +08:00
-LAN-
aa69d90489 fix(makefile): correct uv project path for lint target (#25818) 2025-09-18 20:36:26 +08:00
-LAN-
4ba1292455 refactor: replace print statements with proper logging (#25773) 2025-09-18 20:35:47 +08:00
Maries
bb01c31f30 fix(api): enhance data handling in RagPipelineDslService to filter credentials (#25926)
2025-09-18 18:36:49 +08:00
Wu Tianwei
cd90b2ca9e refactor: replace useInvalid with useInvalidCustomizedTemplateList (#25924) 2025-09-18 18:17:20 +08:00
heyszt
9a65350cf7 fix: rollback aliyun_trace icon (#25921) 2025-09-18 18:01:08 +08:00
quicksand
680eb7a9f6 fix(datasets): retrieval_model null issue when updating dataset info (#25907) 2025-09-18 17:58:06 +08:00
crazywoola
878420463c fix: Message => str (#25876) 2025-09-18 17:57:57 +08:00
zxhlyh
4692e20daf fix: workflow header style (#25922) 2025-09-18 17:53:40 +08:00
QuantumGhost
13fe2ca8fe fix(api): fix single stepping variable loading (#25908) 2025-09-18 17:30:02 +08:00
zxhlyh
1264e7d4f6 fix: use invalid last run (#25911) 2025-09-18 16:52:27 +08:00
Yunlu Wen
4f45978cd9 fix: remote code execution in email endpoints (#25753)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-09-18 16:45:34 +08:00
Saurabh Singh
5a0bf8e028 feat: make SQLALCHEMY_POOL_TIMEOUT configurable (#25468)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-18 16:13:56 +08:00
Wu Tianwei
ffa163a8a8 refactor: simplify portal interactions and manage state in Configure component (#25906)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-09-18 15:57:33 +08:00
Novice
8f86f5749d chore: Update the value of sys.dialogue_count to start from 1. (#25905) 2025-09-18 15:52:52 +08:00
17hz
00d3bf15f3 perf(web): optimize ESLint performance with concurrency flag and remove oxlint (#25899)
Co-authored-by: Claude <noreply@anthropic.com>
2025-09-18 15:50:42 +08:00
17hz
7196c09e9d chore(workflows): remove redundant eslint command from style workflow (#25900) 2025-09-18 15:50:09 +08:00
zxhlyh
fadd9e0bf4 fix: workflow logs list (#25903) 2025-09-18 15:45:37 +08:00
zxhlyh
d8b4bbe067 fix: datasource pinned list (#25896) 2025-09-18 14:52:33 +08:00
GuanMu
24611e375a fix: update Python base image to use bullseye variant (#25895) 2025-09-18 14:38:56 +08:00
lyzno1
ccec582cea chore: add missing template translations in ja-JP (#25892) 2025-09-18 14:37:26 +08:00
Bowen Liang
b2e4107c17 chore: improve opendal storage and ensure closing file after reading files in load_stream method (#25874) 2025-09-18 14:09:19 +08:00
quicksand
87aa070486 feat(api/commands): add migrate-oss to migrate from Local/OpenDAL to … (#25828) 2025-09-18 14:09:00 +08:00
Novice
21230a8eb2 fix: handle None description in MCP tool transformation (#25872)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-09-18 13:11:38 +08:00
-LAN-
85cda47c70 feat: knowledge pipeline (#25360)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: quicksand <quicksandzn@gmail.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Hanqing Zhao <sherry9277@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry <xh001x@hotmail.com>
2025-09-18 12:49:10 +08:00
zyssyz123
7dadb33003 fix: remove billing cache when add or delete app or member (#25885)
2025-09-18 12:18:07 +08:00
1846 changed files with 104304 additions and 32874 deletions


@@ -1,4 +1,4 @@
-FROM mcr.microsoft.com/devcontainers/python:3.12
+FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye
 RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
     && apt-get -y install libgmp-dev libmpfr-dev libmpc-dev


@@ -1,15 +1,16 @@
 #!/bin/bash
+WORKSPACE_ROOT=$(pwd)
 corepack enable
 cd web && pnpm install
 pipx install uv
-echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
-echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
-echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
-echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
-echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
-echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
+echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
+echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
+echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
+echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
+echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
+echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
 source /home/vscode/.bashrc


@@ -8,6 +8,7 @@ on:
- "deploy/enterprise" - "deploy/enterprise"
- "build/**" - "build/**"
- "release/e-*" - "release/e-*"
- "hotfix/**"
tags: tags:
- "*" - "*"


@@ -12,12 +12,13 @@ jobs:
   deploy:
     runs-on: ubuntu-latest
     if: |
-      github.event.workflow_run.conclusion == 'success'
+      github.event.workflow_run.conclusion == 'success' &&
+      github.event.workflow_run.head_branch == 'deploy/dev'
     steps:
       - name: Deploy to server
         uses: appleboy/ssh-action@v0.1.8
         with:
-          host: ${{ secrets.SSH_HOST }}
+          host: ${{ secrets.RAG_SSH_HOST }}
           username: ${{ secrets.SSH_USER }}
           key: ${{ secrets.SSH_PRIVATE_KEY }}
           script: |


@@ -12,7 +12,6 @@ permissions:
   statuses: write
   contents: read
-
 jobs:
   python-style:
     name: Python Style

@@ -44,6 +43,10 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         run: uv sync --project api --dev
+      - name: Run Import Linter
+        if: steps.changed-files.outputs.any_changed == 'true'
+        run: uv run --directory api --dev lint-imports
       - name: Run Basedpyright Checks
         if: steps.changed-files.outputs.any_changed == 'true'
         run: dev/basedpyright-check

@@ -99,7 +102,6 @@ jobs:
         working-directory: ./web
         run: |
           pnpm run lint
-          pnpm run eslint

   docker-compose-template:
     name: Docker Compose Template

.gitignore

@@ -230,4 +230,8 @@ api/.env.backup
 # Benchmark
 scripts/stress-test/setup/config/
 scripts/stress-test/reports/
+
+# mcp
+.playwright-mcp/
+.serena/


@@ -61,8 +61,9 @@ check:
@echo "✅ Code check complete" @echo "✅ Code check complete"
lint: lint:
@echo "🔧 Running ruff format and check with fixes..." @echo "🔧 Running ruff format, check with fixes, and import linter..."
@uv run --directory api --dev sh -c 'ruff format ./api && ruff check --fix ./api' @uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
@uv run --directory api --dev lint-imports
@echo "✅ Linting complete" @echo "✅ Linting complete"
type-check: type-check:


@@ -76,6 +76,7 @@ DB_HOST=localhost
 DB_PORT=5432
 DB_DATABASE=dify
 SQLALCHEMY_POOL_PRE_PING=true
+SQLALCHEMY_POOL_TIMEOUT=30

 # Storage configuration
 # use for store upload files, private keys...

@@ -303,6 +304,8 @@ BAIDU_VECTOR_DB_API_KEY=dify
 BAIDU_VECTOR_DB_DATABASE=dify
 BAIDU_VECTOR_DB_SHARD=1
 BAIDU_VECTOR_DB_REPLICAS=3
+BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
+BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE

 # Upstash configuration
 UPSTASH_VECTOR_URL=your-server-url

@@ -461,6 +464,16 @@ WORKFLOW_CALL_MAX_DEPTH=5
 WORKFLOW_PARALLEL_DEPTH_LIMIT=3
 MAX_VARIABLE_SIZE=204800

+# GraphEngine Worker Pool Configuration
+# Minimum number of workers per GraphEngine instance (default: 1)
+GRAPH_ENGINE_MIN_WORKERS=1
+# Maximum number of workers per GraphEngine instance (default: 10)
+GRAPH_ENGINE_MAX_WORKERS=10
+# Queue depth threshold that triggers worker scale up (default: 3)
+GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
+# Seconds of idle time before scaling down workers (default: 5.0)
+GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
+
 # Workflow storage configuration
 # Options: rdbms, hybrid
 # rdbms: Use only the relational database (default)
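
The four GRAPH_ENGINE_* settings above describe a queue-driven autoscaling policy for the engine's worker pool. A minimal sketch of such a policy, with names and logic assumed for illustration (not Dify's implementation):

import os

class WorkerPoolScaler:
    """Applies the documented scale-up/scale-down rules to a worker count."""

    def __init__(self) -> None:
        self.min_workers = int(os.environ.get("GRAPH_ENGINE_MIN_WORKERS", "1"))
        self.max_workers = int(os.environ.get("GRAPH_ENGINE_MAX_WORKERS", "10"))
        self.scale_up_threshold = int(os.environ.get("GRAPH_ENGINE_SCALE_UP_THRESHOLD", "3"))
        self.scale_down_idle_time = float(os.environ.get("GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME", "5.0"))

    def decide(self, workers: int, queue_depth: int, idle_seconds: float) -> int:
        # Add a worker when the ready queue backs up past the threshold.
        if queue_depth >= self.scale_up_threshold and workers < self.max_workers:
            return workers + 1
        # Retire a worker after a sustained idle period.
        if idle_seconds >= self.scale_down_idle_time and workers > self.min_workers:
            return workers - 1
        return workers

scaler = WorkerPoolScaler()
print(scaler.decide(workers=1, queue_depth=4, idle_seconds=0.0))  # 2
print(scaler.decide(workers=2, queue_depth=0, idle_seconds=6.0))  # 1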

api/.importlinter (new file)

@@ -0,0 +1,105 @@
[importlinter]
root_packages =
    core
    configs
    controllers
    models
    tasks
    services

[importlinter:contract:workflow]
name = Workflow
type = layers
layers =
    graph_engine
    graph_events
    graph
    nodes
    node_events
    entities
containers =
    core.workflow
ignore_imports =
    core.workflow.nodes.base.node -> core.workflow.graph_events
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
    core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
    core.workflow.nodes.node_factory -> core.workflow.graph
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
    core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
    core.workflow.nodes.loop.loop_node -> core.workflow.graph
    core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels

[importlinter:contract:rsc]
name = RSC
type = layers
layers =
    graph_engine
    response_coordinator
containers =
    core.workflow.graph_engine

[importlinter:contract:worker]
name = Worker
type = layers
layers =
    graph_engine
    worker
containers =
    core.workflow.graph_engine

[importlinter:contract:graph-engine-architecture]
name = Graph Engine Architecture
type = layers
layers =
    graph_engine
    orchestration
    command_processing
    event_management
    error_handler
    graph_traversal
    graph_state_manager
    worker_management
    domain
containers =
    core.workflow.graph_engine

[importlinter:contract:domain-isolation]
name = Domain Model Isolation
type = forbidden
source_modules =
    core.workflow.graph_engine.domain
forbidden_modules =
    core.workflow.graph_engine.worker_management
    core.workflow.graph_engine.command_channels
    core.workflow.graph_engine.layers
    core.workflow.graph_engine.protocols

[importlinter:contract:worker-management]
name = Worker Management
type = forbidden
source_modules =
    core.workflow.graph_engine.worker_management
forbidden_modules =
    core.workflow.graph_engine.orchestration
    core.workflow.graph_engine.command_processing
    core.workflow.graph_engine.event_management

[importlinter:contract:graph-traversal-components]
name = Graph Traversal Components
type = layers
layers =
    edge_processor
    skip_propagator
containers =
    core.workflow.graph_engine.graph_traversal

[importlinter:contract:command-channels]
name = Command Channels Independence
type = independence
modules =
    core.workflow.graph_engine.command_channels.in_memory_channel
    core.workflow.graph_engine.command_channels.redis_channel
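
These contracts are what the new "Run Import Linter" CI step and the Makefile's lint target enforce via lint-imports. As a self-contained illustration of the Domain Model Isolation rule (a mock of the check, not import-linter's API):

# Forbidden targets for the domain package, taken from the contract above.
FORBIDDEN = {
    "core.workflow.graph_engine.domain": (
        "core.workflow.graph_engine.worker_management",
        "core.workflow.graph_engine.command_channels",
        "core.workflow.graph_engine.layers",
        "core.workflow.graph_engine.protocols",
    ),
}

def allowed(importer: str, imported: str) -> bool:
    """Return False when an import would violate the forbidden contract."""
    for source, banned in FORBIDDEN.items():
        if importer.startswith(source) and imported.startswith(banned):
            return False
    return True

# Domain models must not reach into worker management:
print(allowed("core.workflow.graph_engine.domain.models",
              "core.workflow.graph_engine.worker_management.pool"))  # False
# The reverse direction is fine:
print(allowed("core.workflow.graph_engine.worker_management.pool",
              "core.workflow.graph_engine.domain.models"))           # True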


@@ -30,6 +30,7 @@ select = [
"RUF022", # unsorted-dunder-all "RUF022", # unsorted-dunder-all
"S506", # unsafe-yaml-load "S506", # unsafe-yaml-load
"SIM", # flake8-simplify rules "SIM", # flake8-simplify rules
"T201", # print-found
"TRY400", # error-instead-of-exception "TRY400", # error-instead-of-exception
"TRY401", # verbose-log-message "TRY401", # verbose-log-message
"UP", # pyupgrade rules "UP", # pyupgrade rules
@@ -91,11 +92,18 @@ ignore = [
"configs/*" = [ "configs/*" = [
"N802", # invalid-function-name "N802", # invalid-function-name
] ]
"core/model_runtime/callbacks/base_callback.py" = [
"T201",
]
"core/workflow/callbacks/workflow_logging_callback.py" = [
"T201",
]
"libs/gmpy2_pkcs10aep_cipher.py" = [ "libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name "N803", # invalid-argument-name
] ]
"tests/*" = [ "tests/*" = [
"F811", # redefined-while-unused "F811", # redefined-while-unused
"T201", # allow print in tests
] ]
[lint.pyflakes] [lint.pyflakes]


@@ -1,4 +1,3 @@
-import os
 import sys

@@ -17,20 +16,20 @@ else:
     # It seems that JetBrains Python debugger does not work well with gevent,
     # so we need to disable gevent in debug mode.
     # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
-    if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
-        from gevent import monkey
-
-        # gevent
-        monkey.patch_all()
-
-        from grpc.experimental import gevent as grpc_gevent  # type: ignore
-
-        # grpc gevent
-        grpc_gevent.init_gevent()
-
-        import psycogreen.gevent  # type: ignore
-
-        psycogreen.gevent.patch_psycopg()
+    # if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
+    #     from gevent import monkey
+    #
+    #     # gevent
+    #     monkey.patch_all()
+    #
+    #     from grpc.experimental import gevent as grpc_gevent  # type: ignore
+    #
+    #     # grpc gevent
+    #     grpc_gevent.init_gevent()
+    #
+    #     import psycogreen.gevent  # type: ignore
+    #
+    #     psycogreen.gevent.patch_psycopg()

 from app_factory import create_app

api/celery_entrypoint.py (new file)

@@ -0,0 +1,13 @@
import psycogreen.gevent as pscycogreen_gevent  # type: ignore
from grpc.experimental import gevent as grpc_gevent  # type: ignore

# grpc gevent
grpc_gevent.init_gevent()
print("gRPC patched with gevent.", flush=True)  # noqa: T201

pscycogreen_gevent.patch_psycopg()
print("psycopg2 patched with gevent.", flush=True)  # noqa: T201

from app import app, celery

__all__ = ["app", "celery"]


@@ -1,7 +1,6 @@
 import base64
 import json
 import logging
-import operator
 import secrets
 from typing import Any
@@ -11,32 +10,41 @@ from flask import current_app
 from pydantic import TypeAdapter
 from sqlalchemy import select
 from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import sessionmaker

 from configs import dify_config
 from constants.languages import languages
-from core.plugin.entities.plugin import ToolProviderID
+from core.helper import encrypter
+from core.plugin.impl.plugin import PluginInstaller
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.rag.models.document import Document
+from core.tools.entities.tool_entities import CredentialType
 from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
 from events.app_event import app_was_created
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
+from extensions.storage.opendal_storage import OpenDALStorage
+from extensions.storage.storage_type import StorageType
 from libs.helper import email as email_validate
 from libs.password import hash_password, password_pattern, valid_password
 from libs.rsa import generate_key_pair
 from models import Tenant
 from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
 from models.dataset import Document as DatasetDocument
-from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
+from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
+from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
 from models.provider import Provider, ProviderModel
+from models.provider_ids import DatasourceProviderID, ToolProviderID
+from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
 from models.tools import ToolOAuthSystemClient
 from services.account_service import AccountService, RegisterService, TenantService
 from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
 from services.plugin.data_migration import PluginDataMigration
 from services.plugin.plugin_migration import PluginMigration
+from services.plugin.plugin_service import PluginService
 from tasks.remove_app_and_related_data_task import delete_draft_variables_batch

 logger = logging.getLogger(__name__)
@@ -54,31 +62,30 @@ def reset_password(email, new_password, password_confirm):
     if str(new_password).strip() != str(password_confirm).strip():
         click.echo(click.style("Passwords do not match.", fg="red"))
         return
-
-    account = db.session.query(Account).where(Account.email == email).one_or_none()
-
-    if not account:
-        click.echo(click.style(f"Account not found for email: {email}", fg="red"))
-        return
-
-    try:
-        valid_password(new_password)
-    except:
-        click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
-        return
-
-    # generate password salt
-    salt = secrets.token_bytes(16)
-    base64_salt = base64.b64encode(salt).decode()
-
-    # encrypt password with salt
-    password_hashed = hash_password(new_password, salt)
-    base64_password_hashed = base64.b64encode(password_hashed).decode()
-    account.password = base64_password_hashed
-    account.password_salt = base64_salt
-    db.session.commit()
-    AccountService.reset_login_error_rate_limit(email)
-    click.echo(click.style("Password reset successfully.", fg="green"))
+    with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
+        account = session.query(Account).where(Account.email == email).one_or_none()
+        if not account:
+            click.echo(click.style(f"Account not found for email: {email}", fg="red"))
+            return
+
+        try:
+            valid_password(new_password)
+        except:
+            click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
+            return
+
+        # generate password salt
+        salt = secrets.token_bytes(16)
+        base64_salt = base64.b64encode(salt).decode()
+
+        # encrypt password with salt
+        password_hashed = hash_password(new_password, salt)
+        base64_password_hashed = base64.b64encode(password_hashed).decode()
+        account.password = base64_password_hashed
+        account.password_salt = base64_salt
+        AccountService.reset_login_error_rate_limit(email)
+        click.echo(click.style("Password reset successfully.", fg="green"))

 @click.command("reset-email", help="Reset the account email.")
@@ -93,22 +100,21 @@ def reset_email(email, new_email, email_confirm):
     if str(new_email).strip() != str(email_confirm).strip():
         click.echo(click.style("New emails do not match.", fg="red"))
         return
-
-    account = db.session.query(Account).where(Account.email == email).one_or_none()
-
-    if not account:
-        click.echo(click.style(f"Account not found for email: {email}", fg="red"))
-        return
-
-    try:
-        email_validate(new_email)
-    except:
-        click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
-        return
-
-    account.email = new_email
-    db.session.commit()
-    click.echo(click.style("Email updated successfully.", fg="green"))
+    with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
+        account = session.query(Account).where(Account.email == email).one_or_none()
+        if not account:
+            click.echo(click.style(f"Account not found for email: {email}", fg="red"))
+            return
+
+        try:
+            email_validate(new_email)
+        except:
+            click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
+            return
+
+        account.email = new_email
+        click.echo(click.style("Email updated successfully.", fg="green"))

 @click.command(
@@ -132,25 +138,24 @@ def reset_encrypt_key_pair():
     if dify_config.EDITION != "SELF_HOSTED":
         click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
         return
-
-    tenants = db.session.query(Tenant).all()
-    for tenant in tenants:
-        if not tenant:
-            click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
-            return
-
-        tenant.encrypt_public_key = generate_key_pair(tenant.id)
-
-        db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
-        db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
-        db.session.commit()
-
-        click.echo(
-            click.style(
-                f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
-                fg="green",
-            )
-        )
+    with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
+        tenants = session.query(Tenant).all()
+        for tenant in tenants:
+            if not tenant:
+                click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
+                return
+
+            tenant.encrypt_public_key = generate_key_pair(tenant.id)
+
+            session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
+            session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
+
+            click.echo(
+                click.style(
+                    f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
+                    fg="green",
+                )
+            )

 @click.command("vdb-migrate", help="Migrate vector db.")
@@ -175,14 +180,15 @@ def migrate_annotation_vector_database():
         try:
             # get apps info
             per_page = 50
-            apps = (
-                db.session.query(App)
-                .where(App.status == "normal")
-                .order_by(App.created_at.desc())
-                .limit(per_page)
-                .offset((page - 1) * per_page)
-                .all()
-            )
+            with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
+                apps = (
+                    session.query(App)
+                    .where(App.status == "normal")
+                    .order_by(App.created_at.desc())
+                    .limit(per_page)
+                    .offset((page - 1) * per_page)
+                    .all()
+                )

             if not apps:
                 break
         except SQLAlchemyError:
@@ -196,26 +202,27 @@ def migrate_annotation_vector_database():
             )
             try:
                 click.echo(f"Creating app annotation index: {app.id}")
-                app_annotation_setting = (
-                    db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
-                )
-
-                if not app_annotation_setting:
-                    skipped_count = skipped_count + 1
-                    click.echo(f"App annotation setting disabled: {app.id}")
-                    continue
-
-                # get dataset_collection_binding info
-                dataset_collection_binding = (
-                    db.session.query(DatasetCollectionBinding)
-                    .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
-                    .first()
-                )
-                if not dataset_collection_binding:
-                    click.echo(f"App annotation collection binding not found: {app.id}")
-                    continue
-
-                annotations = db.session.scalars(
-                    select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
-                ).all()
-                dataset = Dataset(
-                    id=app.id,
-                    tenant_id=app.tenant_id,
+                with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
+                    app_annotation_setting = (
+                        session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
+                    )
+
+                    if not app_annotation_setting:
+                        skipped_count = skipped_count + 1
+                        click.echo(f"App annotation setting disabled: {app.id}")
+                        continue
+
+                    # get dataset_collection_binding info
+                    dataset_collection_binding = (
+                        session.query(DatasetCollectionBinding)
+                        .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
+                        .first()
+                    )
+                    if not dataset_collection_binding:
+                        click.echo(f"App annotation collection binding not found: {app.id}")
+                        continue
+
+                    annotations = session.scalars(
+                        select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
+                    ).all()
+                    dataset = Dataset(
+                        id=app.id,
+                        tenant_id=app.tenant_id,
@@ -732,18 +739,18 @@ where sites.id is null limit 1000"""
             try:
                 app = db.session.query(App).where(App.id == app_id).first()
                 if not app:
-                    print(f"App {app_id} not found")
+                    logger.info("App %s not found", app_id)
                     continue

                 tenant = app.tenant
                 if tenant:
                     accounts = tenant.get_accounts()
                     if not accounts:
-                        print(f"Fix failed for app {app.id}")
+                        logger.info("Fix failed for app %s", app.id)
                         continue

                     account = accounts[0]
-                    print(f"Fixing missing site for app {app.id}")
+                    logger.info("Fixing missing site for app %s", app.id)
                     app_was_created.send(app, account=account)
             except Exception:
                 failed_app_ids.append(app_id)
@@ -1246,15 +1253,17 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
 def _count_orphaned_draft_variables() -> dict[str, Any]:
     """
-    Count orphaned draft variables by app.
+    Count orphaned draft variables by app, including associated file counts.

     Returns:
-        Dictionary with statistics about orphaned variables
+        Dictionary with statistics about orphaned variables and files
     """
-    query = """
+    # Count orphaned variables by app
+    variables_query = """
     SELECT
         wdv.app_id,
-        COUNT(*) as variable_count
+        COUNT(*) as variable_count,
+        COUNT(wdv.file_id) as file_count
     FROM workflow_draft_variables AS wdv
     WHERE NOT EXISTS(
         SELECT 1 FROM apps WHERE apps.id = wdv.app_id
@@ -1264,14 +1273,21 @@ def _count_orphaned_draft_variables() -> dict[str, Any]:
""" """
with db.engine.connect() as conn: with db.engine.connect() as conn:
result = conn.execute(sa.text(query)) result = conn.execute(sa.text(variables_query))
orphaned_by_app = {row[0]: row[1] for row in result} orphaned_by_app = {}
total_files = 0
total_orphaned = sum(orphaned_by_app.values()) for row in result:
app_id, variable_count, file_count = row
orphaned_by_app[app_id] = {"variables": variable_count, "files": file_count}
total_files += file_count
total_orphaned = sum(app_data["variables"] for app_data in orphaned_by_app.values())
app_count = len(orphaned_by_app) app_count = len(orphaned_by_app)
return { return {
"total_orphaned_variables": total_orphaned, "total_orphaned_variables": total_orphaned,
"total_orphaned_files": total_files,
"orphaned_app_count": app_count, "orphaned_app_count": app_count,
"orphaned_by_app": orphaned_by_app, "orphaned_by_app": orphaned_by_app,
} }
@@ -1300,6 +1316,7 @@ def cleanup_orphaned_draft_variables(
     stats = _count_orphaned_draft_variables()

     logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"])
+    logger.info("Found %s associated offload files", stats["total_orphaned_files"])
     logger.info("Across %s non-existent apps", stats["orphaned_app_count"])

     if stats["total_orphaned_variables"] == 0:
@@ -1308,10 +1325,10 @@ def cleanup_orphaned_draft_variables(
     if dry_run:
         logger.info("DRY RUN: Would delete the following:")
-        for app_id, count in sorted(stats["orphaned_by_app"].items(), key=operator.itemgetter(1), reverse=True)[
+        for app_id, data in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1]["variables"], reverse=True)[
             :10
         ]:  # Show top 10
-            logger.info("  App %s: %s variables", app_id, count)
+            logger.info("  App %s: %s variables, %s files", app_id, data["variables"], data["files"])
         if len(stats["orphaned_by_app"]) > 10:
             logger.info("  ... and %s more apps", len(stats["orphaned_by_app"]) - 10)
         return
@@ -1320,7 +1337,8 @@ def cleanup_orphaned_draft_variables(
     if not force:
         click.confirm(
             f"Are you sure you want to delete {stats['total_orphaned_variables']} "
-            f"orphaned draft variables from {stats['orphaned_app_count']} apps?",
+            f"orphaned draft variables and {stats['total_orphaned_files']} associated files "
+            f"from {stats['orphaned_app_count']} apps?",
             abort=True,
         )
@@ -1353,3 +1371,456 @@ def cleanup_orphaned_draft_variables(
             continue

     logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps)
@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_datasource_oauth_client(provider, client_params):
"""
Setup datasource oauth client
"""
provider_id = DatasourceProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
deleted_count = (
db.session.query(DatasourceOauthParamConfig)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow"))
oauth_client = DatasourceOauthParamConfig(
provider=provider_name,
plugin_id=plugin_id,
system_credentials=client_params_dict,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"provider: {provider_name}", fg="green"))
click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
def transform_datasource_credentials():
"""
Transform datasource credentials
"""
try:
installer_manager = PluginInstaller()
plugin_migration = PluginMigration()
notion_plugin_id = "langgenius/notion_datasource"
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
jina_plugin_id = "langgenius/jina_datasource"
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
oauth_credential_type = CredentialType.OAUTH2
api_key_credential_type = CredentialType.API_KEY
# deal notion credentials
deal_notion_count = 0
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
if notion_credentials:
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
for notion_credential in notion_credentials:
tenant_id = notion_credential.tenant_id
if tenant_id not in notion_credentials_tenant_mapping:
notion_credentials_tenant_mapping[tenant_id] = []
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
if not tenant:
continue
try:
# check notion plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if notion_plugin_id not in installed_plugins_ids:
if notion_plugin_unique_identifier:
# install notion plugin
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
auth_count = 0
for notion_tenant_credential in notion_tenant_credentials:
auth_count += 1
# get credential oauth params
access_token = notion_tenant_credential.access_token
# notion info
notion_info = notion_tenant_credential.source_info
workspace_id = notion_info.get("workspace_id")
workspace_name = notion_info.get("workspace_name")
workspace_icon = notion_info.get("workspace_icon")
new_credentials = {
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
"workspace_id": workspace_id,
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
}
datasource_provider = DatasourceProvider(
provider="notion_datasource",
tenant_id=tenant_id,
plugin_id=notion_plugin_id,
auth_type=oauth_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url=workspace_icon or "default",
is_default=False,
)
db.session.add(datasource_provider)
deal_notion_count += 1
except Exception as e:
click.echo(
click.style(
f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
)
)
continue
db.session.commit()
# deal firecrawl credentials
deal_firecrawl_count = 0
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
if firecrawl_credentials:
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for firecrawl_credential in firecrawl_credentials:
tenant_id = firecrawl_credential.tenant_id
if tenant_id not in firecrawl_credentials_tenant_mapping:
firecrawl_credentials_tenant_mapping[tenant_id] = []
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
if not tenant:
continue
try:
# check firecrawl plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if firecrawl_plugin_id not in installed_plugins_ids:
if firecrawl_plugin_unique_identifier:
# install firecrawl plugin
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
auth_count = 0
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
base_url = credentials_json.get("config", {}).get("base_url")
new_credentials = {
"firecrawl_api_key": api_key,
"base_url": base_url,
}
datasource_provider = DatasourceProvider(
provider="firecrawl",
tenant_id=tenant_id,
plugin_id=firecrawl_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
)
db.session.add(datasource_provider)
deal_firecrawl_count += 1
except Exception as e:
click.echo(
click.style(
f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
)
)
continue
db.session.commit()
# deal jina credentials
deal_jina_count = 0
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
if jina_credentials:
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for jina_credential in jina_credentials:
tenant_id = jina_credential.tenant_id
if tenant_id not in jina_credentials_tenant_mapping:
jina_credentials_tenant_mapping[tenant_id] = []
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
if not tenant:
continue
try:
# check jina plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if jina_plugin_id not in installed_plugins_ids:
if jina_plugin_unique_identifier:
# install jina plugin
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
auth_count = 0
for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
new_credentials = {
"integration_secret": api_key,
}
datasource_provider = DatasourceProvider(
provider="jina",
tenant_id=tenant_id,
plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
)
db.session.add(datasource_provider)
deal_jina_count += 1
except Exception as e:
click.echo(
click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red")
)
continue
db.session.commit()
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green"))
click.echo(
click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
)
click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))
@click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.")
@click.option(
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
)
@click.option(
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
)
@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100)
def install_rag_pipeline_plugins(input_file, output_file, workers):
"""
Install rag pipeline plugins
"""
click.echo(click.style("Installing rag pipeline plugins", fg="yellow"))
plugin_migration = PluginMigration()
plugin_migration.install_rag_pipeline_plugins(
input_file,
output_file,
workers,
)
click.echo(click.style("Installing rag pipeline plugins successfully", fg="green"))
@click.command(
"migrate-oss",
help="Migrate files from Local or OpenDAL source to a cloud OSS storage (destination must NOT be local/opendal).",
)
@click.option(
"--path",
"paths",
multiple=True,
help="Storage path prefixes to migrate (repeatable). Defaults: privkeys, upload_files, image_files,"
" tools, website_files, keyword_files, ops_trace",
)
@click.option(
"--source",
type=click.Choice(["local", "opendal"], case_sensitive=False),
default="opendal",
show_default=True,
help="Source storage type to read from",
)
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite destination if file already exists")
@click.option("--dry-run", is_flag=True, default=False, help="Show what would be migrated without uploading")
@click.option("-f", "--force", is_flag=True, help="Skip confirmation and run without prompts")
@click.option(
"--update-db/--no-update-db",
default=True,
help="Update upload_files.storage_type from source type to current storage after migration",
)
def migrate_oss(
paths: tuple[str, ...],
source: str,
overwrite: bool,
dry_run: bool,
force: bool,
update_db: bool,
):
"""
Copy all files under selected prefixes from a source storage
(Local filesystem or OpenDAL-backed) into the currently configured
destination storage backend, then optionally update DB records.
Expected usage: set STORAGE_TYPE (and its credentials) to your target backend.
"""
# Ensure target storage is not local/opendal
if dify_config.STORAGE_TYPE in (StorageType.LOCAL, StorageType.OPENDAL):
click.echo(
click.style(
"Target STORAGE_TYPE must be a cloud OSS (not 'local' or 'opendal').\n"
"Please set STORAGE_TYPE to one of: s3, aliyun-oss, azure-blob, google-storage, tencent-cos, \n"
"volcengine-tos, supabase, oci-storage, huawei-obs, baidu-obs, clickzetta-volume.",
fg="red",
)
)
return
# Default paths if none specified
default_paths = ("privkeys", "upload_files", "image_files", "tools", "website_files", "keyword_files", "ops_trace")
path_list = list(paths) if paths else list(default_paths)
is_source_local = source.lower() == "local"
click.echo(click.style("Preparing migration to target storage.", fg="yellow"))
click.echo(click.style(f"Target storage type: {dify_config.STORAGE_TYPE}", fg="white"))
if is_source_local:
src_root = dify_config.STORAGE_LOCAL_PATH
click.echo(click.style(f"Source: local fs, root: {src_root}", fg="white"))
else:
click.echo(click.style(f"Source: opendal scheme={dify_config.OPENDAL_SCHEME}", fg="white"))
click.echo(click.style(f"Paths to migrate: {', '.join(path_list)}", fg="white"))
click.echo("")
if not force:
click.confirm("Proceed with migration?", abort=True)
# Instantiate source storage
try:
if is_source_local:
src_root = dify_config.STORAGE_LOCAL_PATH
source_storage = OpenDALStorage(scheme="fs", root=src_root)
else:
source_storage = OpenDALStorage(scheme=dify_config.OPENDAL_SCHEME)
except Exception as e:
click.echo(click.style(f"Failed to initialize source storage: {str(e)}", fg="red"))
return
total_files = 0
copied_files = 0
skipped_files = 0
errored_files = 0
copied_upload_file_keys: list[str] = []
for prefix in path_list:
click.echo(click.style(f"Scanning source path: {prefix}", fg="white"))
try:
keys = source_storage.scan(path=prefix, files=True, directories=False)
except FileNotFoundError:
click.echo(click.style(f" -> Skipping missing path: {prefix}", fg="yellow"))
continue
except NotImplementedError:
click.echo(click.style(" -> Source storage does not support scanning.", fg="red"))
return
except Exception as e:
click.echo(click.style(f" -> Error scanning '{prefix}': {str(e)}", fg="red"))
continue
click.echo(click.style(f"Found {len(keys)} files under {prefix}", fg="white"))
for key in keys:
total_files += 1
# check destination existence
if not overwrite:
try:
if storage.exists(key):
skipped_files += 1
continue
except Exception as e:
# existence check failures should not block migration attempt
# but should be surfaced to user as a warning for visibility
click.echo(
click.style(
f" -> Warning: failed target existence check for {key}: {str(e)}",
fg="yellow",
)
)
if dry_run:
copied_files += 1
continue
# read from source and write to destination
try:
data = source_storage.load_once(key)
except FileNotFoundError:
errored_files += 1
click.echo(click.style(f" -> Missing on source: {key}", fg="yellow"))
continue
except Exception as e:
errored_files += 1
click.echo(click.style(f" -> Error reading {key}: {str(e)}", fg="red"))
continue
try:
storage.save(key, data)
copied_files += 1
if prefix == "upload_files":
copied_upload_file_keys.append(key)
except Exception as e:
errored_files += 1
click.echo(click.style(f" -> Error writing {key} to target: {str(e)}", fg="red"))
continue
click.echo("")
click.echo(click.style("Migration summary:", fg="yellow"))
click.echo(click.style(f" Total: {total_files}", fg="white"))
click.echo(click.style(f" Copied: {copied_files}", fg="green"))
click.echo(click.style(f" Skipped: {skipped_files}", fg="white"))
if errored_files:
click.echo(click.style(f" Errors: {errored_files}", fg="red"))
if dry_run:
click.echo(click.style("Dry-run complete. No changes were made.", fg="green"))
return
if errored_files:
click.echo(
click.style(
"Some files failed to migrate. Review errors above before updating DB records.",
fg="yellow",
)
)
if update_db and not force:
if not click.confirm("Proceed to update DB storage_type despite errors?", default=False):
update_db = False
# Optionally update DB records for upload_files.storage_type (only for successfully copied upload_files)
if update_db:
if not copied_upload_file_keys:
click.echo(click.style("No upload_files copied. Skipping DB storage_type update.", fg="yellow"))
else:
try:
source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL
updated = (
db.session.query(UploadFile)
.where(
UploadFile.storage_type == source_storage_type,
UploadFile.key.in_(copied_upload_file_keys),
)
.update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False)
)
db.session.commit()
click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green"))
except Exception as e:
db.session.rollback()
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))

View File

@@ -1,3 +1,3 @@
from .app_config import DifyConfig

-dify_config = DifyConfig()
+dify_config = DifyConfig()  # type: ignore

View File

@@ -1,3 +1,4 @@
from enum import StrEnum
from typing import Literal

from pydantic import (
@@ -505,6 +506,22 @@ class UpdateConfig(BaseSettings):
    )
class WorkflowVariableTruncationConfig(BaseSettings):
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
        # 1,024,000 bytes (~1 MB)
1024_000,
description="Maximum size for variable to trigger final truncation.",
)
WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH: PositiveInt = Field(
100000,
description="maximum length for string to trigger tuncation, measure in number of characters",
)
WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH: PositiveInt = Field(
1000,
description="maximum length for array to trigger truncation.",
)
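An illustrative check against these thresholds (not the actual truncation implementation; the helper name is an assumption):

from configs import dify_config

def needs_truncation(value) -> bool:
    # hypothetical helper mirroring the thresholds defined above
    if isinstance(value, str):
        return len(value) > dify_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH
    if isinstance(value, list):
        return len(value) > dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH
    return False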
class WorkflowConfig(BaseSettings):
    """
    Configuration for workflow execution
@@ -535,6 +552,28 @@ class WorkflowConfig(BaseSettings):
        default=200 * 1024,
    )
# GraphEngine Worker Pool Configuration
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
description="Minimum number of workers per GraphEngine instance",
default=1,
)
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
description="Maximum number of workers per GraphEngine instance",
default=10,
)
GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field(
description="Queue depth threshold that triggers worker scale up",
default=3,
)
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
description="Seconds of idle time before scaling down workers",
default=5.0,
ge=0.1,
)
class WorkflowNodeExecutionConfig(BaseSettings):
    """
@@ -673,11 +712,35 @@ class ToolConfig(BaseSettings):
    )
class TemplateMode(StrEnum):
# unsafe mode allows flexible operations in templates, but may cause security vulnerabilities
UNSAFE = "unsafe"
# sandbox mode restricts some unsafe operations like accessing __class__.
# however, it is still not 100% safe, for example, cpu exploitation can happen.
SANDBOX = "sandbox"
# templating is disabled
DISABLED = "disabled"
class MailConfig(BaseSettings):
    """
    Configuration for email services
    """
MAIL_TEMPLATING_MODE: TemplateMode = Field(
description="Template mode for email services",
default=TemplateMode.SANDBOX,
)
MAIL_TEMPLATING_TIMEOUT: int = Field(
description="""
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
Only available in sandbox mode.""",
default=3,
)
    MAIL_TYPE: str | None = Field(
        description="Email service provider type ('smtp', 'resend', or 'sendGrid'), default to None.",
        default=None,
@@ -1041,5 +1104,6 @@ class FeatureConfig(
    CeleryBeatConfig,
    CeleryScheduleTasksConfig,
    WorkflowLogConfig,
WorkflowVariableTruncationConfig,
):
    pass

View File

@@ -220,11 +220,28 @@ class HostedFetchAppTemplateConfig(BaseSettings):
    )
class HostedFetchPipelineTemplateConfig(BaseSettings):
"""
Configuration for fetching pipeline templates
"""
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
default="remote",
)
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
description="Domain for fetching remote pipeline templates",
default="https://tmpl.dify.ai",
)
class HostedServiceConfig(
    # place the configs in alphabet order
    HostedAnthropicConfig,
    HostedAzureOpenAiConfig,
    HostedFetchAppTemplateConfig,
HostedFetchPipelineTemplateConfig,
    HostedMinmaxConfig,
    HostedOpenAiConfig,
    HostedSparkConfig,

View File

@@ -187,6 +187,11 @@ class DatabaseConfig(BaseSettings):
        default=False,
    )

+    SQLALCHEMY_POOL_TIMEOUT: NonNegativeInt = Field(
+        description="Number of seconds to wait for a connection from the pool before raising a timeout error.",
+        default=30,
+    )
    RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
        description="Number of processes for the retrieval service, default to CPU cores.",
        default=os.cpu_count() or 1,
@@ -216,6 +221,7 @@ class DatabaseConfig(BaseSettings):
"connect_args": connect_args, "connect_args": connect_args,
"pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO, "pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
"pool_reset_on_return": None, "pool_reset_on_return": None,
"pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT,
} }
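For context, pool_timeout bounds how long a connection checkout waits before SQLAlchemy raises its pool TimeoutError; a minimal sketch with an illustrative DSN:

from sqlalchemy import create_engine

engine = create_engine(
    "postgresql://user:pass@localhost/dify",  # illustrative DSN
    pool_size=5,
    pool_timeout=30,  # seconds to wait for a free connection from the pool
)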

View File

@@ -41,3 +41,13 @@ class BaiduVectorDBConfig(BaseSettings):
description="Number of replicas for the Baidu Vector Database (default is 3)", description="Number of replicas for the Baidu Vector Database (default is 3)",
default=3, default=3,
) )
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: str = Field(
description="Analyzer type for inverted index in Baidu Vector Database (default is DEFAULT_ANALYZER)",
default="DEFAULT_ANALYZER",
)
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: str = Field(
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
default="COARSE_MODE",
)

View File

@@ -37,3 +37,15 @@ class OceanBaseVectorConfig(BaseSettings):
"with older versions", "with older versions",
default=False, default=False,
) )
OCEANBASE_FULLTEXT_PARSER: str | None = Field(
description=(
"Fulltext parser to use for text indexing. "
"Built-in options: 'ngram' (N-gram tokenizer for English/numbers), "
"'beng' (Basic English tokenizer), 'space' (Space-based tokenizer), "
"'ngram2' (Improved N-gram tokenizer), 'ik' (Chinese tokenizer). "
"External plugins (require installation): 'japanese_ftparser' (Japanese tokenizer), "
"'thai_ftparser' (Thai tokenizer). Default is 'ik'"
),
default="ik",
)
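Like the other settings, the parser choice is environment-driven; an illustrative override for a Japanese corpus (the external plugin must be installed on the OceanBase server first):

import os

os.environ["OCEANBASE_FULLTEXT_PARSER"] = "japanese_ftparser"  # default is "ik"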

View File

@@ -29,7 +29,7 @@ def no_key_cache_key(namespace: str, key: str) -> str:
# Returns the value for the given key, or None if it does not exist
-def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
+def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any:
    if namespace_cache:
        kv_data = namespace_cache.get(CONFIGURATIONS)
        if kv_data is None:

View File

@@ -5,7 +5,7 @@ import logging
import os
import time

-import requests
+import httpx

logger = logging.getLogger(__name__)
@@ -30,10 +30,10 @@ class NacosHttpClient:
        params = {}
        try:
            self._inject_auth_info(headers, params)
-            response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
+            response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params)
            response.raise_for_status()
            return response.text
-        except requests.RequestException as e:
+        except httpx.HTTPError as e:  # HTTPError also covers the HTTPStatusError from raise_for_status()
            return f"Request to Nacos failed: {e}"

    def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
@@ -78,7 +78,7 @@ class NacosHttpClient:
params = {"username": self.username, "password": self.password} params = {"username": self.username, "password": self.password}
url = "http://" + self.server + "/nacos/v1/auth/login" url = "http://" + self.server + "/nacos/v1/auth/login"
try: try:
resp = requests.request("POST", url, headers=None, params=params) resp = httpx.request("POST", url, headers=None, params=params)
resp.raise_for_status() resp.raise_for_status()
response_data = resp.json() response_data = resp.json()
self.token = response_data.get("accessToken") self.token = response_data.get("accessToken")
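A note on the requests-to-httpx migration above: httpx splits failures into transport errors (httpx.RequestError) and 4xx/5xx responses raised by raise_for_status() (httpx.HTTPStatusError), both under httpx.HTTPError; a minimal sketch with an illustrative URL:

import httpx

try:
    resp = httpx.get("http://nacos.example:8848/nacos/v1/console/health")  # illustrative URL
    resp.raise_for_status()
except httpx.HTTPStatusError as e:
    print("server returned", e.response.status_code)
except httpx.RequestError as e:
    print("request failed:", e)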

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
from contexts.wrapper import RecyclableContextVar

if TYPE_CHECKING:
+    from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
    from core.model_runtime.entities.model_entities import AIModelEntity
    from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
    from core.tools.plugin_tool.provider import PluginToolProviderController
@@ -32,3 +33,11 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
    ContextVar("plugin_model_schemas")
)
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("datasource_plugin_providers_lock")
)

View File

@@ -61,6 +61,7 @@ from . import (
    init_validate,
    ping,
    setup,
+    spec,
    version,
)
@@ -114,6 +115,15 @@ from .datasets import (
    metadata,
    website,
)
from .datasets.rag_pipeline import (
datasource_auth,
datasource_content_preview,
rag_pipeline,
rag_pipeline_datasets,
rag_pipeline_draft_variable,
rag_pipeline_import,
rag_pipeline_workflow,
)
# Import explore controllers
from .explore import (
@@ -238,6 +248,8 @@ __all__ = [
"datasets", "datasets",
"datasets_document", "datasets_document",
"datasets_segments", "datasets_segments",
"datasource_auth",
"datasource_content_preview",
"email_register", "email_register",
"endpoint", "endpoint",
"extension", "extension",
@@ -263,10 +275,16 @@ __all__ = [
"parameter", "parameter",
"ping", "ping",
"plugin", "plugin",
"rag_pipeline",
"rag_pipeline_datasets",
"rag_pipeline_draft_variable",
"rag_pipeline_import",
"rag_pipeline_workflow",
"recommended_app", "recommended_app",
"saved_message", "saved_message",
"setup", "setup",
"site", "site",
"spec",
"statistic", "statistic",
"tags", "tags",
"tool_providers", "tool_providers",

View File

@@ -1,6 +1,7 @@
from datetime import datetime

import pytz  # pip install pytz
+import sqlalchemy as sa
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
@@ -70,7 +71,7 @@ class CompletionConversationApi(Resource):
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
query = db.select(Conversation).where( query = sa.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False) Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
) )
@@ -236,7 +237,7 @@ class ChatConversationApi(Resource):
            .subquery()
        )

-        query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
+        query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))

        if args["keyword"]:
            keyword_filter = f"%{args['keyword']}%"

View File

@@ -16,7 +16,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
+from extensions.ext_database import db
from libs.login import login_required
+from models import App
+from services.workflow_service import WorkflowService


@console_ns.route("/rule-generate")
@@ -205,9 +208,6 @@ class InstructionGenerateApi(Resource):
        try:
            # Generate from nothing for a workflow node
            if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
-                from models import App, db
-                from services.workflow_service import WorkflowService
-
                app = db.session.query(App).where(App.id == args["flow_id"]).first()
                if not app:
                    return {"error": f"app {args['flow_id']} not found"}, 400
@@ -261,6 +261,7 @@ class InstructionGenerateApi(Resource):
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
ideal_output=args["ideal_output"], ideal_output=args["ideal_output"],
workflow_service=WorkflowService(),
) )
return {"error": "incompatible parameters"}, 400 return {"error": "incompatible parameters"}, 400
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:

View File

@@ -62,6 +62,9 @@ class ChatMessageListApi(Resource):
    @account_initialization_required
    @marshal_with(message_infinite_scroll_pagination_fields)
    def get(self, app_model):
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
+            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
        parser.add_argument("first_id", type=uuid_value, location="args")

View File

@@ -50,8 +50,9 @@ class DailyMessageStatistic(Resource):
FROM
    messages
WHERE
-    app_id = :app_id"""
+    app_id = :app_id
+    AND invoke_from != :invoke_from"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc
@@ -187,8 +188,9 @@ class DailyTerminalsStatistic(Resource):
FROM
    messages
WHERE
-    app_id = :app_id"""
+    app_id = :app_id
+    AND invoke_from != :invoke_from"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc
@@ -259,8 +261,9 @@ class DailyTokenCostStatistic(Resource):
FROM
    messages
WHERE
-    app_id = :app_id"""
+    app_id = :app_id
+    AND invoke_from != :invoke_from"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc
@@ -340,8 +343,9 @@ FROM
    messages m
    ON c.id = m.conversation_id
WHERE
-    c.app_id = :app_id"""
+    c.app_id = :app_id
+    AND m.invoke_from != :invoke_from"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc
@@ -426,8 +430,9 @@ LEFT JOIN
    message_feedbacks mf
    ON mf.message_id=m.id AND mf.rating='like'
WHERE
-    m.app_id = :app_id"""
+    m.app_id = :app_id
+    AND m.invoke_from != :invoke_from"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc
@@ -502,8 +507,9 @@ class AverageResponseTimeStatistic(Resource):
FROM
    messages
WHERE
-    app_id = :app_id"""
+    app_id = :app_id
+    AND invoke_from != :invoke_from"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc
@@ -576,8 +582,9 @@ class TokensPerSecondStatistic(Resource):
FROM
    messages
WHERE
-    app_id = :app_id"""
+    app_id = :app_id
+    AND invoke_from != :invoke_from"""
-        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}

        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc
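The added predicate excludes debugger traffic from all of these statistics; as a standalone SQLAlchemy statement it would look roughly like this (the "debugger" literal assumes InvokeFrom.DEBUGGER.value, and the UUID is illustrative):

import sqlalchemy as sa

stmt = sa.text(
    "SELECT COUNT(*) FROM messages WHERE app_id = :app_id AND invoke_from != :invoke_from"
)
params = {"app_id": "00000000-0000-0000-0000-000000000000", "invoke_from": "debugger"}
# with engine.connect() as conn:
#     count = conn.execute(stmt, params).scalar()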

View File

@@ -20,6 +20,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
+from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@@ -536,7 +537,12 @@ class WorkflowTaskStopApi(Resource):
        if not current_user.has_edit_permission:
            raise Forbidden()

-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
+        # Stop using both mechanisms for backward compatibility
+        # Legacy stop flag mechanism (without user check)
+        AppQueueManager.set_stop_flag_no_user_check(task_id)
+        # New graph engine command channel mechanism
+        GraphEngineManager.send_stop_command(task_id)

        return {"result": "success"}

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
-from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
+from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required

View File

@@ -13,14 +13,16 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
+from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
+from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
-from models import App, AppMode, db
+from models import App, AppMode
from models.account import Account
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
@@ -74,6 +76,22 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
    return value_type.exposed_type().value
def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
"""Serialize full_content information for large variables."""
if not variable.is_truncated():
return None
variable_file = variable.variable_file
assert variable_file is not None
return {
"size_bytes": variable_file.size,
"value_type": variable_file.value_type.exposed_type().value,
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
    "id": fields.String,
    "type": fields.String(attribute=lambda model: model.get_variable_type()),
@@ -83,11 +101,13 @@ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"value_type": fields.String(attribute=_serialize_variable_type), "value_type": fields.String(attribute=_serialize_variable_type),
"edited": fields.Boolean(attribute=lambda model: model.edited), "edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean, "visible": fields.Boolean,
"is_truncated": fields.Boolean(attribute=lambda model: model.file_id is not None),
} }
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict( _WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
value=fields.Raw(attribute=_serialize_var_value), value=fields.Raw(attribute=_serialize_var_value),
full_content=fields.Raw(attribute=_serialize_full_content),
) )
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {

View File

@@ -1,6 +1,6 @@
import logging

-import requests
+import httpx
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restx import Resource, fields
@@ -119,7 +119,7 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400 return {"error": "Invalid code"}, 400
try: try:
oauth_provider.get_access_token(code) oauth_provider.get_access_token(code)
except requests.HTTPError as e: except httpx.HTTPStatusError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )
@@ -152,7 +152,7 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
try: try:
oauth_provider.sync_data_source(binding_id) oauth_provider.sync_data_source(binding_id)
except requests.HTTPError as e: except httpx.HTTPStatusError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )

View File

@@ -1,6 +1,6 @@
import logging

-import requests
+import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy import select
@@ -101,8 +101,10 @@ class OAuthCallback(Resource):
        try:
            token = oauth_provider.get_access_token(code)
            user_info = oauth_provider.get_user_info(token)
-        except requests.RequestException as e:
-            error_text = e.response.text if e.response else str(e)
+        except httpx.HTTPError as e:  # HTTPError covers both RequestError and HTTPStatusError
+            error_text = str(e)
+            if isinstance(e, httpx.HTTPStatusError):
+                error_text = e.response.text
            logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
            return {"error": "OAuth process failed"}, 400

View File

@@ -1,4 +1,6 @@
import json
+from collections.abc import Generator
+from typing import cast

from flask import request
from flask_login import current_user
@@ -9,6 +11,8 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
+from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
+from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
@@ -19,6 +23,7 @@ from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
+from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task
@@ -111,6 +116,18 @@ class DataSourceNotionListApi(Resource):
    @marshal_with(integrate_notion_info_list_fields)
    def get(self):
        dataset_id = request.args.get("dataset_id", default=None, type=str)
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if not credential:
raise NotFound("Credential not found.")
        exist_page_ids = []
        with Session(db.engine) as session:
            # import notion in the exist dataset
@@ -134,31 +151,49 @@ class DataSourceNotionListApi(Resource):
                    data_source_info = json.loads(document.data_source_info)
                    exist_page_ids.append(data_source_info["notion_page_id"])
            # get all authorized pages
-            data_source_bindings = session.scalars(
-                select(DataSourceOauthBinding).filter_by(
-                    tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
-                )
-            ).all()
-            if not data_source_bindings:
-                return {"notion_info": []}, 200
-            pre_import_info_list = []
-            for data_source_binding in data_source_bindings:
-                source_info = data_source_binding.source_info
-                pages = source_info["pages"]
-                # Filter out already bound pages
-                for page in pages:
-                    if page["page_id"] in exist_page_ids:
-                        page["is_bound"] = True
-                    else:
-                        page["is_bound"] = False
-                pre_import_info = {
-                    "workspace_name": source_info["workspace_name"],
-                    "workspace_icon": source_info["workspace_icon"],
-                    "workspace_id": source_info["workspace_id"],
-                    "pages": pages,
-                }
-                pre_import_info_list.append(pre_import_info)
-            return {"notion_info": pre_import_info_list}, 200
+        from core.datasource.datasource_manager import DatasourceManager
+
+        datasource_runtime = DatasourceManager.get_datasource_runtime(
+            provider_id="langgenius/notion_datasource/notion_datasource",
+            datasource_name="notion_datasource",
+            tenant_id=current_user.current_tenant_id,
+            datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
+        )
+        datasource_provider_service = DatasourceProviderService()
+        if credential:
+            datasource_runtime.runtime.credentials = credential
+        datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
+        online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
+            datasource_runtime.get_online_document_pages(
+                user_id=current_user.id,
+                datasource_parameters={},
+                provider_type=datasource_runtime.datasource_provider_type(),
+            )
+        )
+        try:
+            pages = []
+            workspace_info = {}
+            for message in online_document_result:
+                result = message.result
+                for info in result:
+                    workspace_info = {
+                        "workspace_id": info.workspace_id,
+                        "workspace_name": info.workspace_name,
+                        "workspace_icon": info.workspace_icon,
+                    }
+                    for page in info.pages:
+                        page_info = {
+                            "page_id": page.page_id,
+                            "page_name": page.page_name,
+                            "type": page.type,
+                            "parent_id": page.parent_id,
+                            "is_bound": page.page_id in exist_page_ids,
+                            "page_icon": page.page_icon,
+                        }
+                        pages.append(page_info)
+        except Exception as e:
+            raise e
+        return {"notion_info": {**workspace_info, "pages": pages}}, 200
class DataSourceNotionApi(Resource):
@@ -166,27 +201,25 @@ class DataSourceNotionApi(Resource):
    @login_required
    @account_initialization_required
    def get(self, workspace_id, page_id, page_type):
+        credential_id = request.args.get("credential_id", default=None, type=str)
+        if not credential_id:
+            raise ValueError("Credential id is required.")
+        datasource_provider_service = DatasourceProviderService()
+        credential = datasource_provider_service.get_datasource_credentials(
+            tenant_id=current_user.current_tenant_id,
+            credential_id=credential_id,
+            provider="notion_datasource",
+            plugin_id="langgenius/notion_datasource",
+        )
        workspace_id = str(workspace_id)
        page_id = str(page_id)
-        with Session(db.engine) as session:
-            data_source_binding = session.execute(
-                select(DataSourceOauthBinding).where(
-                    db.and_(
-                        DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                        DataSourceOauthBinding.provider == "notion",
-                        DataSourceOauthBinding.disabled == False,
-                        DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
-                    )
-                )
-            ).scalar_one_or_none()
-            if not data_source_binding:
-                raise NotFound("Data source binding not found.")

        extractor = NotionExtractor(
            notion_workspace_id=workspace_id,
            notion_obj_id=page_id,
            notion_page_type=page_type,
-            notion_access_token=data_source_binding.access_token,
+            notion_access_token=credential.get("integration_secret"),
            tenant_id=current_user.current_tenant_id,
        )
@@ -211,10 +244,12 @@ class DataSourceNotionApi(Resource):
        extract_settings = []
        for notion_info in notion_info_list:
            workspace_id = notion_info["workspace_id"]
+            credential_id = notion_info.get("credential_id")
            for page in notion_info["pages"]:
                extract_setting = ExtractSetting(
                    datasource_type=DatasourceType.NOTION.value,
                    notion_info={
+                        "credential_id": credential_id,
                        "notion_workspace_id": workspace_id,
                        "notion_obj_id": page["page_id"],
                        "notion_page_type": page["type"],

View File

@@ -20,7 +20,6 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
-from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
@@ -33,6 +32,7 @@ from fields.document_fields import document_status_fields
from libs.login import login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
+from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@@ -337,6 +337,15 @@ class DatasetApi(Resource):
location="json", location="json",
help="Invalid external knowledge api id.", help="Invalid external knowledge api id.",
) )
parser.add_argument(
"icon_info",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid icon info.",
)
        args = parser.parse_args()
        data = request.get_json()
@@ -387,7 +396,7 @@ class DatasetApi(Resource):
        dataset_id_str = str(dataset_id)
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor or current_user.is_dataset_operator:
+        if not (current_user.is_editor or current_user.is_dataset_operator):
            raise Forbidden()

        try:
@@ -503,10 +512,12 @@ class DatasetIndexingEstimateApi(Resource):
            notion_info_list = args["info_list"]["notion_info_list"]
            for notion_info in notion_info_list:
                workspace_id = notion_info["workspace_id"]
+                credential_id = notion_info.get("credential_id")
                for page in notion_info["pages"]:
                    extract_setting = ExtractSetting(
                        datasource_type=DatasourceType.NOTION.value,
                        notion_info={
+                            "credential_id": credential_id,
                            "notion_workspace_id": workspace_id,
                            "notion_obj_id": page["page_id"],
                            "notion_page_type": page["type"],
@@ -730,6 +741,19 @@ class DatasetApiDeleteApi(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<string:status>")
class DatasetEnableApiApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id, status):
dataset_id_str = str(dataset_id)
DatasetService.update_dataset_api_status(dataset_id_str, status == "enable")
return {"result": "success"}, 200
@console_ns.route("/datasets/api-base-info") @console_ns.route("/datasets/api-base-info")
class DatasetApiBaseUrlApi(Resource): class DatasetApiBaseUrlApi(Resource):
@api.doc("get_dataset_api_base_info") @api.doc("get_dataset_api_base_info")
@@ -758,7 +782,6 @@ class DatasetRetrievalSettingApi(Resource):
                | VectorType.TIDB_VECTOR
                | VectorType.CHROMA
                | VectorType.PGVECTO_RS
-                | VectorType.BAIDU
                | VectorType.VIKINGDB
                | VectorType.UPSTASH
            ):
@@ -785,6 +808,7 @@ class DatasetRetrievalSettingApi(Resource):
                | VectorType.TENCENT
                | VectorType.MATRIXONE
                | VectorType.CLICKZETTA
+                | VectorType.BAIDU
            ):
                return {
                    "retrieval_method": [
@@ -814,7 +838,6 @@ class DatasetRetrievalSettingMockApi(Resource):
                | VectorType.TIDB_VECTOR
                | VectorType.CHROMA
                | VectorType.PGVECTO_RS
-                | VectorType.BAIDU
                | VectorType.VIKINGDB
                | VectorType.UPSTASH
            ):
@@ -839,6 +862,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                | VectorType.HUAWEI_CLOUD
                | VectorType.MATRIXONE
                | VectorType.CLICKZETTA
+                | VectorType.BAIDU
            ):
                return {
                    "retrieval_method": [

View File

@@ -1,8 +1,10 @@
+import json
import logging
from argparse import ArgumentTypeError
from collections.abc import Sequence
from typing import Literal, cast

+import sqlalchemy as sa
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
@@ -53,6 +55,7 @@ from fields.document_fields import (
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
+from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@@ -209,13 +212,13 @@ class DatasetDocumentListApi(Resource):
if sort == "hit_count": if sort == "hit_count":
sub_query = ( sub_query = (
db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
.group_by(DocumentSegment.document_id) .group_by(DocumentSegment.document_id)
.subquery() .subquery()
) )
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(Document.position), sort_logic(Document.position),
) )
elif sort == "created_at": elif sort == "created_at":
@@ -542,6 +545,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                        extract_setting = ExtractSetting(
                            datasource_type=DatasourceType.NOTION.value,
                            notion_info={
+                                "credential_id": data_source_info["credential_id"],
                                "notion_workspace_id": data_source_info["notion_workspace_id"],
                                "notion_obj_id": data_source_info["notion_page_id"],
                                "notion_page_type": data_source_info["type"],
@@ -716,7 +720,7 @@ class DocumentApi(DocumentResource):
            response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
        elif metadata == "without":
            dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
            data_source_info = document.data_source_detail_dict
            response = {
                "id": document.id,
@@ -1108,3 +1112,64 @@ class WebsiteDocumentSyncApi(DocumentResource):
        DocumentService.sync_website_document(dataset_id, document)
        return {"result": "success"}, 200
+
+
+class DocumentPipelineExecutionLogApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, document_id):
+        dataset_id = str(dataset_id)
+        document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        document = DocumentService.get_document(dataset.id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        log = (
+            db.session.query(DocumentPipelineExecutionLog)
+            .filter_by(document_id=document_id)
+            .order_by(DocumentPipelineExecutionLog.created_at.desc())
+            .first()
+        )
+        if not log:
+            return {
+                "datasource_info": None,
+                "datasource_type": None,
+                "input_data": None,
+                "datasource_node_id": None,
+            }, 200
+        return {
+            "datasource_info": json.loads(log.datasource_info),
+            "datasource_type": log.datasource_type,
+            "input_data": log.input_data,
+            "datasource_node_id": log.datasource_node_id,
+        }, 200
+
+
+api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
+api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
+api.add_resource(DatasetInitApi, "/datasets/init")
+api.add_resource(
+    DocumentIndexingEstimateApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate"
+)
+api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
+api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
+api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
+api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
+api.add_resource(
+    DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
+)
+api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
+api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
+api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
+api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
+api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
+api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
+api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
+api.add_resource(
+    DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
+)
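Editor's note: a minimal client-side call to the new pipeline-execution-log endpoint might look like the sketch below. The base URL, IDs, and the bearer-header placeholder are assumptions for illustration, not values from this diff; authenticate however your console session normally does.

import httpx

BASE = "http://localhost:5001/console/api"  # assumed local deployment
dataset_id = "00000000-0000-0000-0000-000000000000"   # placeholder
document_id = "11111111-1111-1111-1111-111111111111"  # placeholder

resp = httpx.get(
    f"{BASE}/datasets/{dataset_id}/documents/{document_id}/pipeline-execution-log",
    headers={"Authorization": "Bearer <console-access-token>"},  # placeholder auth
)
resp.raise_for_status()
log = resp.json()
# All four fields are None when the document has no pipeline execution log yet.
print(log["datasource_type"], log["datasource_node_id"])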

View File

@@ -71,3 +71,9 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
    error_code = "child_chunk_delete_index_error"
    description = "Delete child chunk index failed: {message}"
    code = 500
+
+
+class PipelineNotFoundError(BaseHTTPException):
+    error_code = "pipeline_not_found"
+    description = "Pipeline not found."
+    code = 404

View File

@@ -148,7 +148,7 @@ class ExternalApiTemplateApi(Resource):
        external_knowledge_api_id = str(external_knowledge_api_id)
        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor or current_user.is_dataset_operator:
+        if not (current_user.is_editor or current_user.is_dataset_operator):
            raise Forbidden()
        ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
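Editor's note: the fix here is pure operator precedence. In Python, `not` binds tighter than `or`, so the old guard rejected every dataset operator; the parenthesized form only rejects users who are neither editors nor dataset operators:

# Standalone illustration of the precedence bug fixed above.
is_editor, is_dataset_operator = False, True
assert (not is_editor or is_dataset_operator) is True        # old form: Forbidden raised
assert (not (is_editor or is_dataset_operator)) is False     # new form: request allowed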

View File

@@ -0,0 +1,362 @@
from fastapi.encoders import jsonable_encoder
from flask import make_response, redirect, request
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound

from configs import dify_config
from controllers.console import api
from controllers.console.wraps import (
    account_initialization_required,
    setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService


class DatasourcePluginOAuthAuthorizationUrl(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider_id: str):
        user = current_user
        tenant_id = user.current_tenant_id
        if not current_user.is_editor:
            raise Forbidden()
        credential_id = request.args.get("credential_id")
        datasource_provider_id = DatasourceProviderID(provider_id)
        provider_name = datasource_provider_id.provider_name
        plugin_id = datasource_provider_id.plugin_id
        oauth_config = DatasourceProviderService().get_oauth_client(
            tenant_id=tenant_id,
            datasource_provider_id=datasource_provider_id,
        )
        if not oauth_config:
            raise ValueError(f"No OAuth Client Config for {provider_id}")
        context_id = OAuthProxyService.create_proxy_context(
            user_id=current_user.id,
            tenant_id=tenant_id,
            plugin_id=plugin_id,
            provider=provider_name,
            credential_id=credential_id,
        )
        oauth_handler = OAuthHandler()
        redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
        authorization_url_response = oauth_handler.get_authorization_url(
            tenant_id=tenant_id,
            user_id=user.id,
            plugin_id=plugin_id,
            provider=provider_name,
            redirect_uri=redirect_uri,
            system_credentials=oauth_config,
        )
        response = make_response(jsonable_encoder(authorization_url_response))
        response.set_cookie(
            "context_id",
            context_id,
            httponly=True,
            samesite="Lax",
            max_age=OAuthProxyService.__MAX_AGE__,
        )
        return response


class DatasourceOAuthCallback(Resource):
    @setup_required
    def get(self, provider_id: str):
        context_id = request.cookies.get("context_id") or request.args.get("context_id")
        if not context_id:
            raise Forbidden("context_id not found")
        context = OAuthProxyService.use_proxy_context(context_id)
        if context is None:
            raise Forbidden("Invalid context_id")
        user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
        datasource_provider_id = DatasourceProviderID(provider_id)
        plugin_id = datasource_provider_id.plugin_id
        datasource_provider_service = DatasourceProviderService()
        oauth_client_params = datasource_provider_service.get_oauth_client(
            tenant_id=tenant_id,
            datasource_provider_id=datasource_provider_id,
        )
        if not oauth_client_params:
            raise NotFound()
        redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
        oauth_handler = OAuthHandler()
        oauth_response = oauth_handler.get_credentials(
            tenant_id=tenant_id,
            user_id=user_id,
            plugin_id=plugin_id,
            provider=datasource_provider_id.provider_name,
            redirect_uri=redirect_uri,
            system_credentials=oauth_client_params,
            request=request,
        )
        credential_id = context.get("credential_id")
        if credential_id:
            datasource_provider_service.reauthorize_datasource_oauth_provider(
                tenant_id=tenant_id,
                provider_id=datasource_provider_id,
                avatar_url=oauth_response.metadata.get("avatar_url") or None,
                name=oauth_response.metadata.get("name") or None,
                expire_at=oauth_response.expires_at,
                credentials=dict(oauth_response.credentials),
                credential_id=context.get("credential_id"),
            )
        else:
            datasource_provider_service.add_datasource_oauth_provider(
                tenant_id=tenant_id,
                provider_id=datasource_provider_id,
                avatar_url=oauth_response.metadata.get("avatar_url") or None,
                name=oauth_response.metadata.get("name") or None,
                expire_at=oauth_response.expires_at,
                credentials=dict(oauth_response.credentials),
            )
        return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")


class DatasourceAuth(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_id: str):
        if not current_user.is_editor:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
        )
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
        args = parser.parse_args()
        datasource_provider_id = DatasourceProviderID(provider_id)
        datasource_provider_service = DatasourceProviderService()
        try:
            datasource_provider_service.add_datasource_api_key_provider(
                tenant_id=current_user.current_tenant_id,
                provider_id=datasource_provider_id,
                credentials=args["credentials"],
                name=args["name"],
            )
        except CredentialsValidateFailedError as ex:
            raise ValueError(str(ex))
        return {"result": "success"}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider_id: str):
        datasource_provider_id = DatasourceProviderID(provider_id)
        datasource_provider_service = DatasourceProviderService()
        datasources = datasource_provider_service.list_datasource_credentials(
            tenant_id=current_user.current_tenant_id,
            provider=datasource_provider_id.provider_name,
            plugin_id=datasource_provider_id.plugin_id,
        )
        return {"result": datasources}, 200


class DatasourceAuthDeleteApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_id: str):
        datasource_provider_id = DatasourceProviderID(provider_id)
        plugin_id = datasource_provider_id.plugin_id
        provider_name = datasource_provider_id.provider_name
        if not current_user.is_editor:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.remove_datasource_credentials(
            tenant_id=current_user.current_tenant_id,
            auth_id=args["credential_id"],
            provider=provider_name,
            plugin_id=plugin_id,
        )
        return {"result": "success"}, 200


class DatasourceAuthUpdateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_id: str):
        datasource_provider_id = DatasourceProviderID(provider_id)
        parser = reqparse.RequestParser()
        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
        parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
        parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()
        if not current_user.is_editor:
            raise Forbidden()
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.update_datasource_credentials(
            tenant_id=current_user.current_tenant_id,
            auth_id=args["credential_id"],
            provider=datasource_provider_id.provider_name,
            plugin_id=datasource_provider_id.plugin_id,
            credentials=args.get("credentials", {}),
            name=args.get("name", None),
        )
        return {"result": "success"}, 201


class DatasourceAuthListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        datasource_provider_service = DatasourceProviderService()
        datasources = datasource_provider_service.get_all_datasource_credentials(
            tenant_id=current_user.current_tenant_id
        )
        return {"result": jsonable_encoder(datasources)}, 200


class DatasourceHardCodeAuthListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        datasource_provider_service = DatasourceProviderService()
        datasources = datasource_provider_service.get_hard_code_datasource_credentials(
            tenant_id=current_user.current_tenant_id
        )
        return {"result": jsonable_encoder(datasources)}, 200


class DatasourceAuthOauthCustomClient(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_id: str):
        if not current_user.is_editor:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
        parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
        args = parser.parse_args()
        datasource_provider_id = DatasourceProviderID(provider_id)
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.setup_oauth_custom_client_params(
            tenant_id=current_user.current_tenant_id,
            datasource_provider_id=datasource_provider_id,
            client_params=args.get("client_params", {}),
            enabled=args.get("enable_oauth_custom_client", False),
        )
        return {"result": "success"}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, provider_id: str):
        datasource_provider_id = DatasourceProviderID(provider_id)
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.remove_oauth_custom_client_params(
            tenant_id=current_user.current_tenant_id,
            datasource_provider_id=datasource_provider_id,
        )
        return {"result": "success"}, 200


class DatasourceAuthDefaultApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_id: str):
        if not current_user.is_editor:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("id", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()
        datasource_provider_id = DatasourceProviderID(provider_id)
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.set_default_datasource_provider(
            tenant_id=current_user.current_tenant_id,
            datasource_provider_id=datasource_provider_id,
            credential_id=args["id"],
        )
        return {"result": "success"}, 200


class DatasourceUpdateProviderNameApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_id: str):
        if not current_user.is_editor:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
        parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()
        datasource_provider_id = DatasourceProviderID(provider_id)
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.update_datasource_provider_name(
            tenant_id=current_user.current_tenant_id,
            datasource_provider_id=datasource_provider_id,
            name=args["name"],
            credential_id=args["credential_id"],
        )
        return {"result": "success"}, 200


api.add_resource(
    DatasourcePluginOAuthAuthorizationUrl,
    "/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
)
api.add_resource(
    DatasourceOAuthCallback,
    "/oauth/plugin/<path:provider_id>/datasource/callback",
)
api.add_resource(
    DatasourceAuth,
    "/auth/plugin/datasource/<path:provider_id>",
)
api.add_resource(
    DatasourceAuthUpdateApi,
    "/auth/plugin/datasource/<path:provider_id>/update",
)
api.add_resource(
    DatasourceAuthDeleteApi,
    "/auth/plugin/datasource/<path:provider_id>/delete",
)
api.add_resource(
    DatasourceAuthListApi,
    "/auth/plugin/datasource/list",
)
api.add_resource(
    DatasourceHardCodeAuthListApi,
    "/auth/plugin/datasource/default-list",
)
api.add_resource(
    DatasourceAuthOauthCustomClient,
    "/auth/plugin/datasource/<path:provider_id>/custom-client",
)
api.add_resource(
    DatasourceAuthDefaultApi,
    "/auth/plugin/datasource/<path:provider_id>/default",
)
api.add_resource(
    DatasourceUpdateProviderNameApi,
    "/auth/plugin/datasource/<path:provider_id>/update-name",
)
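Editor's note: the first leg of the OAuth round trip above can be exercised as in the sketch below. The provider ID, host, and bearer-header placeholder are assumptions; the endpoint path comes from the registrations above. The `context_id` cookie set by this call is what `/datasource/callback` later consumes.

import httpx

BASE = "http://localhost:5001/console/api"          # assumed local deployment
provider_id = "langgenius/notion/notion"            # hypothetical plugin provider ID

resp = httpx.get(
    f"{BASE}/oauth/plugin/{provider_id}/datasource/get-authorization-url",
    headers={"Authorization": "Bearer <console-access-token>"},  # placeholder auth
)
resp.raise_for_status()
# The body carries the provider's authorization URL to open in a browser.
print(resp.json())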

View File

@@ -0,0 +1,57 @@
from flask_restx import (  # type: ignore
    Resource,  # type: ignore
    reqparse,
)
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService


class DataSourceContentPreviewApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_rag_pipeline
    def post(self, pipeline: Pipeline, node_id: str):
        """
        Run datasource content preview
        """
        if not isinstance(current_user, Account):
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("datasource_type", type=str, required=True, location="json")
        parser.add_argument("credential_id", type=str, required=False, location="json")
        args = parser.parse_args()
        inputs = args.get("inputs")
        if inputs is None:
            raise ValueError("missing inputs")
        datasource_type = args.get("datasource_type")
        if datasource_type is None:
            raise ValueError("missing datasource_type")
        rag_pipeline_service = RagPipelineService()
        preview_content = rag_pipeline_service.run_datasource_node_preview(
            pipeline=pipeline,
            node_id=node_id,
            user_inputs=inputs,
            account=current_user,
            datasource_type=datasource_type,
            is_published=True,
            credential_id=args.get("credential_id"),
        )
        return preview_content, 200


api.add_resource(
    DataSourceContentPreviewApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
)
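Editor's note: a preview request against this endpoint might look like the sketch below. All IDs, the host, the auth placeholder, and the "online_document" datasource type are illustrative assumptions; only the URL shape and the "inputs"/"datasource_type"/"credential_id" fields come from the parser above.

import httpx

BASE = "http://localhost:5001/console/api"                    # assumed
pipeline_id = "00000000-0000-0000-0000-000000000000"          # placeholder
node_id = "1718690001234"                                     # placeholder node ID

resp = httpx.post(
    f"{BASE}/rag/pipelines/{pipeline_id}/workflows/published/datasource/nodes/{node_id}/preview",
    json={"inputs": {}, "datasource_type": "online_document", "credential_id": None},
    headers={"Authorization": "Bearer <console-access-token>"},  # placeholder auth
)
print(resp.json())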

View File

@@ -0,0 +1,164 @@
import logging

from flask import request
from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session

from controllers.console import api
from controllers.console.wraps import (
    account_initialization_required,
    enterprise_license_required,
    knowledge_pipeline_publish_enabled,
    setup_required,
)
from extensions.ext_database import db
from libs.login import login_required
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService

logger = logging.getLogger(__name__)


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 to 40 characters.")
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


class PipelineTemplateListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self):
        type = request.args.get("type", default="built-in", type=str)
        language = request.args.get("language", default="en-US", type=str)
        # get pipeline templates
        pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
        return pipeline_templates, 200


class PipelineTemplateDetailApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self, template_id: str):
        type = request.args.get("type", default="built-in", type=str)
        rag_pipeline_service = RagPipelineService()
        pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
        return pipeline_template, 200


class CustomizedPipelineTemplateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def patch(self, template_id: str):
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )
        parser.add_argument(
            "icon_info",
            type=dict,
            location="json",
            nullable=True,
        )
        args = parser.parse_args()
        pipeline_template_info = PipelineTemplateInfoEntity(**args)
        RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
        return 200

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def delete(self, template_id: str):
        RagPipelineService.delete_customized_pipeline_template(template_id)
        return 200

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, template_id: str):
        with Session(db.engine) as session:
            template = (
                session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
            )
            if not template:
                raise ValueError("Customized pipeline template not found.")
        return {"data": template.yaml_content}, 200


class PublishCustomizedPipelineTemplateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    @knowledge_pipeline_publish_enabled
    def post(self, pipeline_id: str):
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )
        parser.add_argument(
            "icon_info",
            type=dict,
            location="json",
            nullable=True,
        )
        args = parser.parse_args()
        rag_pipeline_service = RagPipelineService()
        rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
        return {"result": "success"}


api.add_resource(
    PipelineTemplateListApi,
    "/rag/pipeline/templates",
)
api.add_resource(
    PipelineTemplateDetailApi,
    "/rag/pipeline/templates/<string:template_id>",
)
api.add_resource(
    CustomizedPipelineTemplateApi,
    "/rag/pipeline/customized/templates/<string:template_id>",
)
api.add_resource(
    PublishCustomizedPipelineTemplateApi,
    "/rag/pipelines/<string:pipeline_id>/customized/publish",
)
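Editor's note: the list endpoint takes the "type" and "language" query parameters parsed in PipelineTemplateListApi above. A sketch of a client call, with host and auth as placeholder assumptions:

import httpx

resp = httpx.get(
    "http://localhost:5001/console/api/rag/pipeline/templates",   # assumed host
    params={"type": "built-in", "language": "zh-Hans"},
    headers={"Authorization": "Bearer <console-access-token>"},   # placeholder auth
)
resp.raise_for_status()
print(resp.json())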

View File

@@ -0,0 +1,114 @@
from flask_login import current_user  # type: ignore
from flask_restx import Resource, marshal, reqparse  # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

import services
from controllers.console import api
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_rate_limit_check,
    setup_required,
)
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 to 40 characters.")
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


class CreateRagPipelineDatasetApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_rate_limit_check("knowledge")
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument(
            "yaml_content",
            type=str,
            nullable=False,
            required=True,
            help="yaml_content is required.",
        )
        args = parser.parse_args()
        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        if not current_user.is_dataset_editor:
            raise Forbidden()
        rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(
            name="",
            description="",
            icon_info=IconInfo(
                icon="📙",
                icon_background="#FFF4ED",
                icon_type="emoji",
            ),
            permission=DatasetPermissionEnum.ONLY_ME,
            partial_member_list=None,
            yaml_content=args["yaml_content"],
        )
        try:
            with Session(db.engine) as session:
                rag_pipeline_dsl_service = RagPipelineDslService(session)
                import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
                    tenant_id=current_user.current_tenant_id,
                    rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
                )
                if rag_pipeline_dataset_create_entity.permission == "partial_members":
                    DatasetPermissionService.update_partial_member_list(
                        current_user.current_tenant_id,
                        import_info["dataset_id"],
                        rag_pipeline_dataset_create_entity.partial_member_list,
                    )
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()
        return import_info, 201


class CreateEmptyRagPipelineDatasetApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_rate_limit_check("knowledge")
    def post(self):
        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        if not current_user.is_dataset_editor:
            raise Forbidden()
        dataset = DatasetService.create_empty_rag_pipeline_dataset(
            tenant_id=current_user.current_tenant_id,
            rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
                name="",
                description="",
                icon_info=IconInfo(
                    icon="📙",
                    icon_background="#FFF4ED",
                    icon_type="emoji",
                ),
                permission=DatasetPermissionEnum.ONLY_ME,
                partial_member_list=None,
            ),
        )
        return marshal(dataset, dataset_detail_fields), 201


api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
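Editor's note: creating a dataset from a pipeline DSL export might look like the sketch below. The host, auth placeholder, and file name are assumptions; the endpoint and the required "yaml_content" field come from CreateRagPipelineDatasetApi above.

import httpx

yaml_content = open("pipeline.yml").read()  # hypothetical pipeline DSL export
resp = httpx.post(
    "http://localhost:5001/console/api/rag/pipeline/dataset",   # assumed host
    json={"yaml_content": yaml_content},
    headers={"Authorization": "Bearer <console-access-token>"},  # placeholder auth
)
print(resp.status_code, resp.json())  # 201 with import info on success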

View File

@@ -0,0 +1,389 @@
import logging
from typing import Any, NoReturn

from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.app.error import (
    DraftWorkflowNotExist,
)
from controllers.console.app.workflow_draft_variable import (
    _WORKFLOW_DRAFT_VARIABLE_FIELDS,
    _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models.account import Account
from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService

logger = logging.getLogger(__name__)


def _convert_values_to_json_serializable_object(value: Segment) -> Any:
    if isinstance(value, FileSegment):
        return value.value.model_dump()
    elif isinstance(value, ArrayFileSegment):
        return [i.model_dump() for i in value.value]
    elif isinstance(value, SegmentGroup):
        return [_convert_values_to_json_serializable_object(i) for i in value.value]
    else:
        return value.value


def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
    value = variable.get_value()
    # create a copy of the value to avoid affecting the model cache.
    value = value.model_copy(deep=True)
    # Refresh the url signature before returning it to client.
    if isinstance(value, FileSegment):
        file = value.value
        file.remote_url = file.generate_url()
    elif isinstance(value, ArrayFileSegment):
        files = value.value
        for file in files:
            file.remote_url = file.generate_url()
    return _convert_values_to_json_serializable_object(value)


def _create_pagination_parser():
    parser = reqparse.RequestParser()
    parser.add_argument(
        "page",
        type=inputs.int_range(1, 100_000),
        required=False,
        default=1,
        location="args",
        help="the page of data requested",
    )
    parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
    return parser


def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
    return var_list.variables


_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
    "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
    "total": fields.Raw(),
}

_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
    "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}


def _api_prerequisite(f):
    """Common prerequisites for all draft workflow variable APIs.

    It ensures the following conditions are satisfied:

    - Dify has been properly set up.
    - The request user has logged in and initialized.
    - The requested app is a workflow or a chat flow.
    - The request user has the edit permission for the app.
    """

    @setup_required
    @login_required
    @account_initialization_required
    @get_rag_pipeline
    def wrapper(*args, **kwargs):
        if not isinstance(current_user, Account) or not current_user.is_editor:
            raise Forbidden()
        return f(*args, **kwargs)

    return wrapper


class RagPipelineVariableCollectionApi(Resource):
    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
    def get(self, pipeline: Pipeline):
        """
        Get draft workflow variables
        """
        parser = _create_pagination_parser()
        args = parser.parse_args()

        # fetch draft workflow by app_model
        rag_pipeline_service = RagPipelineService()
        workflow_exist = rag_pipeline_service.is_workflow_exist(pipeline=pipeline)
        if not workflow_exist:
            raise DraftWorkflowNotExist()

        # fetch draft workflow by app_model
        with Session(bind=db.engine, expire_on_commit=False) as session:
            draft_var_srv = WorkflowDraftVariableService(
                session=session,
            )
            workflow_vars = draft_var_srv.list_variables_without_values(
                app_id=pipeline.id,
                page=args.page,
                limit=args.limit,
            )

        return workflow_vars

    @_api_prerequisite
    def delete(self, pipeline: Pipeline):
        draft_var_srv = WorkflowDraftVariableService(
            session=db.session(),
        )
        draft_var_srv.delete_workflow_variables(pipeline.id)
        db.session.commit()
        return Response("", 204)


def validate_node_id(node_id: str) -> NoReturn | None:
    if node_id in [
        CONVERSATION_VARIABLE_NODE_ID,
        SYSTEM_VARIABLE_NODE_ID,
    ]:
        # NOTE(QuantumGhost): While we store the system and conversation variables as node variables
        # with specific `node_id` in database, we still want to make the API separated. By disallowing
        # accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
        # we mitigate the risk of users of the API depending on its implementation details.
        #
        # ref: [Hyrum's Law](https://www.hyrumslaw.com/)
        raise InvalidArgumentError(
            f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
        )
    return None


class RagPipelineNodeVariableCollectionApi(Resource):
    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
    def get(self, pipeline: Pipeline, node_id: str):
        validate_node_id(node_id)
        with Session(bind=db.engine, expire_on_commit=False) as session:
            draft_var_srv = WorkflowDraftVariableService(
                session=session,
            )
            node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id)

        return node_vars

    @_api_prerequisite
    def delete(self, pipeline: Pipeline, node_id: str):
        validate_node_id(node_id)
        srv = WorkflowDraftVariableService(db.session())
        srv.delete_node_variables(pipeline.id, node_id)
        db.session.commit()
        return Response("", 204)


class RagPipelineVariableApi(Resource):
    _PATCH_NAME_FIELD = "name"
    _PATCH_VALUE_FIELD = "value"

    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
    def get(self, pipeline: Pipeline, variable_id: str):
        draft_var_srv = WorkflowDraftVariableService(
            session=db.session(),
        )
        variable = draft_var_srv.get_variable(variable_id=variable_id)
        if variable is None:
            raise NotFoundError(description=f"variable not found, id={variable_id}")
        if variable.app_id != pipeline.id:
            raise NotFoundError(description=f"variable not found, id={variable_id}")
        return variable

    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
    def patch(self, pipeline: Pipeline, variable_id: str):
        # Request payload for file types:
        #
        # Local File:
        #
        #     {
        #         "type": "image",
        #         "transfer_method": "local_file",
        #         "url": "",
        #         "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
        #     }
        #
        # Remote File:
        #
        #     {
        #         "type": "image",
        #         "transfer_method": "remote_url",
        #         "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
        #         "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
        #     }
        parser = reqparse.RequestParser()
        parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
        # Parse 'value' field as-is to maintain its original data structure
        parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")

        draft_var_srv = WorkflowDraftVariableService(
            session=db.session(),
        )
        args = parser.parse_args(strict=True)

        variable = draft_var_srv.get_variable(variable_id=variable_id)
        if variable is None:
            raise NotFoundError(description=f"variable not found, id={variable_id}")
        if variable.app_id != pipeline.id:
            raise NotFoundError(description=f"variable not found, id={variable_id}")

        new_name = args.get(self._PATCH_NAME_FIELD, None)
        raw_value = args.get(self._PATCH_VALUE_FIELD, None)
        if new_name is None and raw_value is None:
            return variable

        new_value = None
        if raw_value is not None:
            if variable.value_type == SegmentType.FILE:
                if not isinstance(raw_value, dict):
                    raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
                raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id)
            elif variable.value_type == SegmentType.ARRAY_FILE:
                if not isinstance(raw_value, list):
                    raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
                if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
                    raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
                raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id)
            new_value = build_segment_with_type(variable.value_type, raw_value)
        draft_var_srv.update_variable(variable, name=new_name, value=new_value)
        db.session.commit()
        return variable

    @_api_prerequisite
    def delete(self, pipeline: Pipeline, variable_id: str):
        draft_var_srv = WorkflowDraftVariableService(
            session=db.session(),
        )
        variable = draft_var_srv.get_variable(variable_id=variable_id)
        if variable is None:
            raise NotFoundError(description=f"variable not found, id={variable_id}")
        if variable.app_id != pipeline.id:
            raise NotFoundError(description=f"variable not found, id={variable_id}")
        draft_var_srv.delete_variable(variable)
        db.session.commit()
        return Response("", 204)


class RagPipelineVariableResetApi(Resource):
    @_api_prerequisite
    def put(self, pipeline: Pipeline, variable_id: str):
        draft_var_srv = WorkflowDraftVariableService(
            session=db.session(),
        )

        rag_pipeline_service = RagPipelineService()
        draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
        if draft_workflow is None:
            raise NotFoundError(
                f"Draft workflow not found, pipeline_id={pipeline.id}",
            )
        variable = draft_var_srv.get_variable(variable_id=variable_id)
        if variable is None:
            raise NotFoundError(description=f"variable not found, id={variable_id}")
        if variable.app_id != pipeline.id:
            raise NotFoundError(description=f"variable not found, id={variable_id}")

        resetted = draft_var_srv.reset_variable(draft_workflow, variable)
        db.session.commit()
        if resetted is None:
            return Response("", 204)
        else:
            return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)


def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
    with Session(bind=db.engine, expire_on_commit=False) as session:
        draft_var_srv = WorkflowDraftVariableService(
            session=session,
        )
        if node_id == CONVERSATION_VARIABLE_NODE_ID:
            draft_vars = draft_var_srv.list_conversation_variables(pipeline.id)
        elif node_id == SYSTEM_VARIABLE_NODE_ID:
            draft_vars = draft_var_srv.list_system_variables(pipeline.id)
        else:
            draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id)
    return draft_vars


class RagPipelineSystemVariableCollectionApi(Resource):
    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
    def get(self, pipeline: Pipeline):
        return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)


class RagPipelineEnvironmentVariableCollectionApi(Resource):
    @_api_prerequisite
    def get(self, pipeline: Pipeline):
        """
        Get draft workflow environment variables
        """
        # fetch draft workflow by app_model
        rag_pipeline_service = RagPipelineService()
        workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
        if workflow is None:
            raise DraftWorkflowNotExist()

        env_vars = workflow.environment_variables
        env_vars_list = []
        for v in env_vars:
            env_vars_list.append(
                {
                    "id": v.id,
                    "type": "env",
                    "name": v.name,
                    "description": v.description,
                    "selector": v.selector,
                    "value_type": v.value_type.value,
                    "value": v.value,
                    # Do not track edited for env vars.
                    "edited": False,
                    "visible": True,
                    "editable": True,
                }
            )

        return {"items": env_vars_list}


api.add_resource(
    RagPipelineVariableCollectionApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
)
api.add_resource(
    RagPipelineNodeVariableCollectionApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
)
api.add_resource(
    RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
)
api.add_resource(
    RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
)
api.add_resource(
    RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
)
api.add_resource(
    RagPipelineEnvironmentVariableCollectionApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
)
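Editor's note: patching a file-typed draft variable might look like the sketch below. The host, IDs, and auth placeholder are assumptions; the JSON body mirrors the "Local File" payload documented in RagPipelineVariableApi.patch above.

import httpx

BASE = "http://localhost:5001/console/api"                     # assumed
pipeline_id = "00000000-0000-0000-0000-000000000000"           # placeholder
variable_id = "11111111-1111-1111-1111-111111111111"           # placeholder

resp = httpx.patch(
    f"{BASE}/rag/pipelines/{pipeline_id}/workflows/draft/variables/{variable_id}",
    json={
        "value": {
            "type": "image",
            "transfer_method": "local_file",
            "url": "",
            "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190",
        }
    },
    headers={"Authorization": "Bearer <console-access-token>"},  # placeholder auth
)
resp.raise_for_status()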

View File

@@ -0,0 +1,149 @@
from typing import cast

from flask_login import current_user  # type: ignore
from flask_restx import Resource, marshal_with, reqparse  # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
    account_initialization_required,
    setup_required,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required
from models import Account
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService


class RagPipelineImportApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(pipeline_import_fields)
    def post(self):
        # Check user role first
        if not current_user.is_editor:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("mode", type=str, required=True, location="json")
        parser.add_argument("yaml_content", type=str, location="json")
        parser.add_argument("yaml_url", type=str, location="json")
        parser.add_argument("name", type=str, location="json")
        parser.add_argument("description", type=str, location="json")
        parser.add_argument("icon_type", type=str, location="json")
        parser.add_argument("icon", type=str, location="json")
        parser.add_argument("icon_background", type=str, location="json")
        parser.add_argument("pipeline_id", type=str, location="json")
        args = parser.parse_args()

        # Create service with session
        with Session(db.engine) as session:
            import_service = RagPipelineDslService(session)
            # Import app
            account = cast(Account, current_user)
            result = import_service.import_rag_pipeline(
                account=account,
                import_mode=args["mode"],
                yaml_content=args.get("yaml_content"),
                yaml_url=args.get("yaml_url"),
                pipeline_id=args.get("pipeline_id"),
                dataset_name=args.get("name"),
            )
            session.commit()

        # Return appropriate status code based on result
        status = result.status
        if status == ImportStatus.FAILED.value:
            return result.model_dump(mode="json"), 400
        elif status == ImportStatus.PENDING.value:
            return result.model_dump(mode="json"), 202
        return result.model_dump(mode="json"), 200


class RagPipelineImportConfirmApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(pipeline_import_fields)
    def post(self, import_id):
        # Check user role first
        if not current_user.is_editor:
            raise Forbidden()

        # Create service with session
        with Session(db.engine) as session:
            import_service = RagPipelineDslService(session)
            # Confirm import
            account = cast(Account, current_user)
            result = import_service.confirm_import(import_id=import_id, account=account)
            session.commit()

        # Return appropriate status code based on result
        if result.status == ImportStatus.FAILED.value:
            return result.model_dump(mode="json"), 400
        return result.model_dump(mode="json"), 200


class RagPipelineImportCheckDependenciesApi(Resource):
    @setup_required
    @login_required
    @get_rag_pipeline
    @account_initialization_required
    @marshal_with(pipeline_import_check_dependencies_fields)
    def get(self, pipeline: Pipeline):
        if not current_user.is_editor:
            raise Forbidden()

        with Session(db.engine) as session:
            import_service = RagPipelineDslService(session)
            result = import_service.check_dependencies(pipeline=pipeline)

        return result.model_dump(mode="json"), 200


class RagPipelineExportApi(Resource):
    @setup_required
    @login_required
    @get_rag_pipeline
    @account_initialization_required
    def get(self, pipeline: Pipeline):
        if not current_user.is_editor:
            raise Forbidden()

        # Add include_secret params
        parser = reqparse.RequestParser()
        parser.add_argument("include_secret", type=str, default="false", location="args")
        args = parser.parse_args()

        with Session(db.engine) as session:
            export_service = RagPipelineDslService(session)
            result = export_service.export_rag_pipeline_dsl(
                pipeline=pipeline, include_secret=args["include_secret"] == "true"
            )

        return {"data": result}, 200


# Import Rag Pipeline
api.add_resource(
    RagPipelineImportApi,
    "/rag/pipelines/imports",
)
api.add_resource(
    RagPipelineImportConfirmApi,
    "/rag/pipelines/imports/<string:import_id>/confirm",
)
api.add_resource(
    RagPipelineImportCheckDependenciesApi,
    "/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
)
api.add_resource(
    RagPipelineExportApi,
    "/rag/pipelines/<string:pipeline_id>/exports",
)
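Editor's note: a two-step import-then-confirm flow might look like the sketch below. The host, auth placeholder, file name, the "yaml-content" mode string, and the "id" field on the result are assumptions for illustration; the 202/PENDING handling follows the status codes returned by RagPipelineImportApi above.

import httpx

BASE = "http://localhost:5001/console/api"                       # assumed
headers = {"Authorization": "Bearer <console-access-token>"}     # placeholder auth

resp = httpx.post(
    f"{BASE}/rag/pipelines/imports",
    json={"mode": "yaml-content", "yaml_content": open("pipeline.yml").read()},
    headers=headers,
)
result = resp.json()
if resp.status_code == 202:  # ImportStatus.PENDING: needs explicit confirmation
    confirm = httpx.post(f"{BASE}/rag/pipelines/imports/{result['id']}/confirm", headers=headers)
    result = confirm.json()
print(result)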

File diff suppressed because it is too large.

View File

@@ -0,0 +1,46 @@
from collections.abc import Callable
from functools import wraps

from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.account import Account
from models.dataset import Pipeline


def get_rag_pipeline(
    view: Callable | None = None,
):
    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
            if not kwargs.get("pipeline_id"):
                raise ValueError("missing pipeline_id in path parameters")
            if not isinstance(current_user, Account):
                raise ValueError("current_user is not an account")

            pipeline_id = kwargs.get("pipeline_id")
            pipeline_id = str(pipeline_id)
            del kwargs["pipeline_id"]

            pipeline = (
                db.session.query(Pipeline)
                .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
                .first()
            )
            if not pipeline:
                raise PipelineNotFoundError()
            kwargs["pipeline"] = pipeline
            return view_func(*args, **kwargs)

        return decorated_view

    if view is None:
        return decorator
    else:
        return decorator(view)
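Editor's note: a minimal sketch of how a resource consumes this decorator. The resource class itself is hypothetical and not part of this diff; the pattern (the `pipeline_id` URL parameter is popped, resolved against the current tenant, and re-injected as a `Pipeline` instance) is exactly what the decorator above implements.

from flask_restx import Resource

from controllers.console.datasets.wraps import get_rag_pipeline
from models.dataset import Pipeline


class ExamplePipelineApi(Resource):  # hypothetical, for illustration only
    @get_rag_pipeline
    def get(self, pipeline: Pipeline):
        # `pipeline` is already tenant-scoped; PipelineNotFoundError was
        # raised by the decorator if no matching row exists.
        return {"id": pipeline.id, "name": pipeline.name}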

View File

@@ -20,6 +20,7 @@ from core.errors.error import (
    QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
+from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.login import current_user
from models.model import AppMode, InstalledApp
@@ -82,6 +83,11 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
            raise NotWorkflowAppError()

        assert current_user is not None
-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
+        # Stop using both mechanisms for backward compatibility
+        # Legacy stop flag mechanism (without user check)
+        AppQueueManager.set_stop_flag_no_user_check(task_id)
+        # New graph engine command channel mechanism
+        GraphEngineManager.send_stop_command(task_id)

        return {"result": "success"}

View File

@@ -20,6 +20,7 @@ from controllers.console.wraps import (
    cloud_edition_billing_resource_check,
    setup_required,
)
+from extensions.ext_database import db
from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from models import Account
@@ -68,10 +69,11 @@ class FileApi(Resource):
        if source not in ("datasets", None):
            source = None

+        if not isinstance(current_user, Account):
+            raise ValueError("Invalid user account")
+
        try:
-            if not isinstance(current_user, Account):
-                raise ValueError("Invalid user account")
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                filename=file.filename,
                content=file.read(),
                mimetype=file.mimetype,
@@ -92,7 +94,7 @@ class FilePreviewApi(Resource):
    @account_initialization_required
    def get(self, file_id):
        file_id = str(file_id)
-        text = FileService.get_file_preview(file_id)
+        text = FileService(db.engine).get_file_preview(file_id)
        return {"content": text}

View File

@@ -14,6 +14,7 @@ from controllers.common.errors import (
)
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
+from extensions.ext_database import db
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from models.account import Account
from services.file_service import FileService
@@ -61,7 +62,7 @@ class RemoteFileUploadApi(Resource):
        try:
            user = cast(Account, current_user)
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                filename=file_info.filename,
                content=content,
                mimetype=file_info.mimetype,

View File

@@ -0,0 +1,35 @@
import logging

from flask_restx import Resource

from controllers.console import api
from controllers.console.wraps import (
    account_initialization_required,
    setup_required,
)
from core.schemas.schema_manager import SchemaManager
from libs.login import login_required

logger = logging.getLogger(__name__)


class SpecSchemaDefinitionsApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """
        Get system JSON Schema definitions specification
        Used for frontend component type mapping
        """
        try:
            schema_manager = SchemaManager()
            schema_definitions = schema_manager.get_all_schema_definitions()
            return schema_definitions, 200
        except Exception:
            logger.exception("Failed to get schema definitions from local registry")
            # Return empty array as fallback
            return [], 200


api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")

View File

@@ -1,7 +1,7 @@
import json
import logging

-import requests
+import httpx
from flask_restx import Resource, fields, reqparse
from packaging import version
@@ -57,7 +57,11 @@ class VersionApi(Resource):
            return result

        try:
-            response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
+            response = httpx.get(
+                check_update_url,
+                params={"current_version": args["current_version"]},
+                timeout=httpx.Timeout(connect=3, read=10),
+            )
        except Exception as error:
            logger.warning("Check update version error: %s.", str(error))
            result["version"] = args["current_version"]

View File

@@ -21,11 +21,11 @@ from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
-from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
+from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

View File

@@ -227,7 +227,7 @@ class WebappLogoWorkspaceApi(Resource):
            raise UnsupportedFileTypeError()

        try:
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                filename=file.filename,
                content=file.read(),
                mimetype=file.mimetype,

View File

@@ -279,3 +279,14 @@ def is_allow_transfer_owner(view: Callable[P, R]):
        abort(403)

    return decorated
+
+
+def knowledge_pipeline_publish_enabled(view):
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        features = FeatureService.get_features(current_user.current_tenant_id)
+        if features.knowledge_pipeline.publish_enabled:
+            return view(*args, **kwargs)
+        abort(403)
+
+    return decorated

View File

@@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
import services
from controllers.common.errors import UnsupportedFileTypeError
from controllers.files import files_ns
+from extensions.ext_database import db
from services.account_service import TenantService
from services.file_service import FileService
@@ -28,7 +29,7 @@ class ImagePreviewApi(Resource):
            return {"content": "Invalid request."}, 400

        try:
-            generator, mimetype = FileService.get_image_preview(
+            generator, mimetype = FileService(db.engine).get_image_preview(
                file_id=file_id,
                timestamp=timestamp,
                nonce=nonce,
@@ -57,7 +58,7 @@ class FilePreviewApi(Resource):
            return {"content": "Invalid request."}, 400

        try:
-            generator, upload_file = FileService.get_file_generator_by_file_id(
+            generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
                file_id=file_id,
                timestamp=args["timestamp"],
                nonce=args["nonce"],
@@ -108,7 +109,7 @@ class WorkspaceWebappLogoApi(Resource):
            raise NotFound("webapp logo is not found")

        try:
-            generator, mimetype = FileService.get_public_image_preview(
+            generator, mimetype = FileService(db.engine).get_public_image_preview(
                webapp_logo_file_id,
            )
        except services.errors.file.UnsupportedFileTypeError:

View File

@@ -8,7 +8,7 @@ from controllers.common.errors import UnsupportedFileTypeError
from controllers.files import files_ns
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
-from models import db as global_db
+from extensions.ext_database import db as global_db


@files_ns.route("/tools/<uuid:file_id>.<string:extension>")

View File

@@ -420,7 +420,12 @@ class PluginUploadFileRequestApi(Resource):
    )
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
        # generate signed url
-        url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id)
+        url = get_signed_file_url_for_plugin(
+            filename=payload.filename,
+            mimetype=payload.mimetype,
+            tenant_id=tenant_model.id,
+            user_id=user_model.id,
+        )
        return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()

View File

@@ -32,11 +32,20 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
        user_model = (
            session.query(EndUser)
            .where(
-                EndUser.session_id == user_id,
+                EndUser.id == user_id,
                EndUser.tenant_id == tenant_id,
            )
            .first()
        )
+        if not user_model:
+            user_model = (
+                session.query(EndUser)
+                .where(
+                    EndUser.session_id == user_id,
+                    EndUser.tenant_id == tenant_id,
+                )
+                .first()
+            )
        if not user_model:
            user_model = EndUser(
                tenant_id=tenant_id,

View File

@@ -12,8 +12,9 @@ from controllers.common.errors import (
)
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
+from extensions.ext_database import db
from fields.file_fields import build_file_model
-from models.model import App, EndUser
+from models import App, EndUser
from services.file_service import FileService
@@ -52,7 +53,7 @@ class FileApi(Resource):
            raise FilenameNotExistsError

        try:
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                filename=file.filename,
                content=file.read(),
                mimetype=file.mimetype,

View File

@@ -26,7 +26,8 @@ from core.errors.error import (
)
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
-from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
+from core.workflow.enums import WorkflowExecutionStatus
+from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs import helper
@@ -262,7 +263,12 @@ class WorkflowTaskStopApi(Resource):
        if app_mode != AppMode.WORKFLOW:
            raise NotWorkflowAppError()

-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
+        # Stop using both mechanisms for backward compatibility
+        # Legacy stop flag mechanism (without user check)
+        AppQueueManager.set_stop_flag_no_user_check(task_id)
+        # New graph engine command channel mechanism
+        GraphEngineManager.send_stop_command(task_id)

        return {"result": "success"}

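Stopping a workflow now fans out to both mechanisms: the legacy Redis stop flag (kept for consumers that still poll it) and the new graph engine command channel. A small helper capturing that pairing — both calls are taken from the diff, the wrapper function itself is just a sketch:

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.workflow.graph_engine.manager import GraphEngineManager


def stop_workflow_task(task_id: str) -> None:
    # Legacy consumers poll the Redis stop flag; keep setting it.
    AppQueueManager.set_stop_flag_no_user_check(task_id)
    # New graph engine executions listen on a per-task command channel instead.
    GraphEngineManager.send_stop_command(task_id)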
View File

@@ -13,13 +13,13 @@ from controllers.service_api.wraps import (
     validate_dataset_token,
 )
 from core.model_runtime.entities.model_entities import ModelType
-from core.plugin.entities.plugin import ModelProviderID
 from core.provider_manager import ProviderManager
 from fields.dataset_fields import dataset_detail_fields
 from fields.tag_fields import build_dataset_tag_fields
 from libs.login import current_user
 from models.account import Account
 from models.dataset import Dataset, DatasetPermissionEnum
+from models.provider_ids import ModelProviderID
 from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
 from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
 from services.tag_service import TagService

View File

@@ -124,7 +124,12 @@ class DocumentAddByTextApi(DatasetApiResource):
                 args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
             )

-        upload_file = FileService.upload_text(text=str(text), text_name=str(name))
+        if not current_user:
+            raise ValueError("current_user is required")
+        upload_file = FileService(db.engine).upload_text(
+            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
+        )
         data_source = {
             "type": "upload_file",
             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@@ -134,6 +139,9 @@ class DocumentAddByTextApi(DatasetApiResource):
         # validate args
         DocumentService.document_create_args_validate(knowledge_config)

+        if not current_user:
+            raise ValueError("current_user is required")
+
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
@@ -199,7 +207,11 @@ class DocumentUpdateByTextApi(DatasetApiResource):
         name = args.get("name")
         if text is None or name is None:
             raise ValueError("Both text and name must be strings.")
-        upload_file = FileService.upload_text(text=str(text), text_name=str(name))
+        if not current_user:
+            raise ValueError("current_user is required")
+        upload_file = FileService(db.engine).upload_text(
+            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
+        )
         data_source = {
             "type": "upload_file",
             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@@ -301,8 +313,9 @@ class DocumentAddByFileApi(DatasetApiResource):
         if not isinstance(current_user, EndUser):
             raise ValueError("Invalid user account")
-        upload_file = FileService.upload_file(
+        if not current_user:
+            raise ValueError("current_user is required")
+        upload_file = FileService(db.engine).upload_file(
             filename=file.filename,
             content=file.read(),
             mimetype=file.mimetype,
@@ -390,10 +403,14 @@ class DocumentUpdateByFileApi(DatasetApiResource):
         if not file.filename:
             raise FilenameNotExistsError

+        if not current_user:
+            raise ValueError("current_user is required")
+        if not isinstance(current_user, EndUser):
+            raise ValueError("Invalid user account")
         try:
-            if not isinstance(current_user, EndUser):
-                raise ValueError("Invalid user account")
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                 filename=file.filename,
                 content=file.read(),
                 mimetype=file.mimetype,
@@ -571,7 +588,7 @@ class DocumentApi(DatasetApiResource):
             response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
         elif metadata == "without":
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,
@@ -604,7 +621,7 @@ class DocumentApi(DatasetApiResource):
             }
         else:
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,

View File

@@ -47,3 +47,9 @@ class DatasetInUseError(BaseHTTPException):
     error_code = "dataset_in_use"
     description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
     code = 409
+
+
+class PipelineRunError(BaseHTTPException):
+    error_code = "pipeline_run_error"
+    description = "An error occurred while running the pipeline."
+    code = 500

View File

@@ -133,7 +133,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
         return 204


-@service_api_ns.route("/datasets/metadata/built-in")
+@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")
 class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
     @service_api_ns.doc("get_built_in_fields")
     @service_api_ns.doc(description="Get all built-in metadata fields")
@@ -143,7 +143,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
             401: "Unauthorized - invalid API token",
         }
     )
-    def get(self, tenant_id):
+    def get(self, tenant_id, dataset_id):
         """Get all built-in metadata fields."""
         built_in_fields = MetadataService.get_built_in_fields()
         return {"fields": built_in_fields}, 200

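The built-in metadata route now nests under a dataset, so clients must include a dataset ID in the path even though the handler body ignores it. A client-side sketch of the migration using requests (base URL, token, and ID are placeholders):

import requests

BASE = "https://api.example.com/v1"  # placeholder base URL
headers = {"Authorization": "Bearer dataset-api-key"}  # placeholder token

dataset_id = "00000000-0000-0000-0000-000000000000"
# Old path "/datasets/metadata/built-in" is no longer routed; use the nested form.
resp = requests.get(f"{BASE}/datasets/{dataset_id}/metadata/built-in", headers=headers)
print(resp.json()["fields"])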
View File

@@ -0,0 +1,242 @@
+import string
+import uuid
+from collections.abc import Generator
+from typing import Any
+
+from flask import request
+from flask_restx import reqparse
+from flask_restx.reqparse import ParseResult, RequestParser
+from werkzeug.exceptions import Forbidden
+
+import services
+from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
+from controllers.service_api import service_api_ns
+from controllers.service_api.dataset.error import PipelineRunError
+from controllers.service_api.wraps import DatasetApiResource
+from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
+from core.app.entities.app_invoke_entities import InvokeFrom
+from libs import helper
+from libs.login import current_user
+from models.account import Account
+from models.dataset import Pipeline
+from models.engine import db
+from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
+from services.file_service import FileService
+from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity
+from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
+from services.rag_pipeline.rag_pipeline import RagPipelineService
+
+
+@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource-plugins")
+class DatasourcePluginsApi(DatasetApiResource):
+    """Resource for datasource plugins."""
+
+    @service_api_ns.doc(shortcut="list_rag_pipeline_datasource_plugins")
+    @service_api_ns.doc(description="List all datasource plugins for a rag pipeline")
+    @service_api_ns.doc(
+        path={
+            "dataset_id": "Dataset ID",
+        }
+    )
+    @service_api_ns.doc(
+        params={
+            "is_published": "Whether to get published or draft datasource plugins "
+            "(true for published, false for draft, default: true)"
+        }
+    )
+    @service_api_ns.doc(
+        responses={
+            200: "Datasource plugins retrieved successfully",
+            401: "Unauthorized - invalid API token",
+        }
+    )
+    def get(self, tenant_id: str, dataset_id: str):
+        """Get datasource plugins for a rag pipeline."""
+        # Get query parameter to determine published or draft
+        is_published: bool = request.args.get("is_published", default=True, type=bool)
+        rag_pipeline_service: RagPipelineService = RagPipelineService()
+        datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
+            tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published
+        )
+        return datasource_plugins, 200
+
+
+@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
+class DatasourceNodeRunApi(DatasetApiResource):
+    """Resource for datasource node run."""
+
+    @service_api_ns.doc(shortcut="pipeline_datasource_node_run")
+    @service_api_ns.doc(description="Run a datasource node for a rag pipeline")
+    @service_api_ns.doc(
+        path={
+            "dataset_id": "Dataset ID",
+        }
+    )
+    @service_api_ns.doc(
+        body={
+            "inputs": "User input variables",
+            "datasource_type": "Datasource type, e.g. online_document",
+            "credential_id": "Credential ID",
+            "is_published": "Whether to run the published or draft datasource node "
+            "(true for published, false for draft, default: true)",
+        }
+    )
+    @service_api_ns.doc(
+        responses={
+            200: "Datasource node run successfully",
+            401: "Unauthorized - invalid API token",
+        }
+    )
+    def post(self, tenant_id: str, dataset_id: str, node_id: str):
+        """Run a datasource node for a rag pipeline."""
+        # Parse JSON body arguments
+        parser: RequestParser = reqparse.RequestParser()
+        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("datasource_type", type=str, required=True, location="json")
+        parser.add_argument("credential_id", type=str, required=False, location="json")
+        parser.add_argument("is_published", type=bool, required=True, location="json")
+        args: ParseResult = parser.parse_args()
+
+        datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args)
+        assert isinstance(current_user, Account)
+        rag_pipeline_service: RagPipelineService = RagPipelineService()
+        pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
+        return helper.compact_generate_response(
+            PipelineGenerator.convert_to_event_stream(
+                rag_pipeline_service.run_datasource_workflow_node(
+                    pipeline=pipeline,
+                    node_id=node_id,
+                    user_inputs=datasource_node_run_api_entity.inputs,
+                    account=current_user,
+                    datasource_type=datasource_node_run_api_entity.datasource_type,
+                    is_published=datasource_node_run_api_entity.is_published,
+                    credential_id=datasource_node_run_api_entity.credential_id,
+                )
+            )
+        )
+
+
+@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/run")
+class PipelineRunApi(DatasetApiResource):
+    """Resource for running a rag pipeline."""
+
+    @service_api_ns.doc(shortcut="pipeline_run")
+    @service_api_ns.doc(description="Run a rag pipeline")
+    @service_api_ns.doc(
+        path={
+            "dataset_id": "Dataset ID",
+        }
+    )
+    @service_api_ns.doc(
+        body={
+            "inputs": "User input variables",
+            "datasource_type": "Datasource type, e.g. online_document",
+            "datasource_info_list": "Datasource info list",
+            "start_node_id": "Start node ID",
+            "is_published": "Whether to run the published or draft pipeline "
+            "(true for published, false for draft, default: true)",
+            "response_mode": "Whether to stream the response (streaming or blocking), default: blocking",
+        }
+    )
+    @service_api_ns.doc(
+        responses={
+            200: "Pipeline run successfully",
+            401: "Unauthorized - invalid API token",
+        }
+    )
+    def post(self, tenant_id: str, dataset_id: str):
+        """Run a rag pipeline."""
+        parser: RequestParser = reqparse.RequestParser()
+        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("datasource_type", type=str, required=True, location="json")
+        parser.add_argument("datasource_info_list", type=list, required=True, location="json")
+        parser.add_argument("start_node_id", type=str, required=True, location="json")
+        parser.add_argument("is_published", type=bool, required=True, default=True, location="json")
+        parser.add_argument(
+            "response_mode",
+            type=str,
+            required=True,
+            choices=["streaming", "blocking"],
+            default="blocking",
+            location="json",
+        )
+        args: ParseResult = parser.parse_args()
+
+        if not isinstance(current_user, Account):
+            raise Forbidden()
+        rag_pipeline_service: RagPipelineService = RagPipelineService()
+        pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
+        try:
+            response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
+                pipeline=pipeline,
+                user=current_user,
+                args=args,
+                invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER,
+                streaming=args.get("response_mode") == "streaming",
+            )
+            return helper.compact_generate_response(response)
+        except Exception as ex:
+            raise PipelineRunError(description=str(ex))
+
+
+@service_api_ns.route("/datasets/pipeline/file-upload")
+class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
+    """Resource for uploading a file to a knowledgebase pipeline."""
+
+    @service_api_ns.doc(shortcut="knowledgebase_pipeline_file_upload")
+    @service_api_ns.doc(description="Upload a file to a knowledgebase pipeline")
+    @service_api_ns.doc(
+        responses={
+            201: "File uploaded successfully",
+            400: "Bad request - no file or invalid file",
+            401: "Unauthorized - invalid API token",
+            413: "File too large",
+            415: "Unsupported file type",
+        }
+    )
+    def post(self, tenant_id: str):
+        """Upload a file to a knowledgebase pipeline.
+
+        Accepts a single file upload via multipart/form-data.
+        """
+        # check file
+        if "file" not in request.files:
+            raise NoFileUploadedError()
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+
+        file = request.files["file"]
+        if not file.mimetype:
+            raise UnsupportedFileTypeError()
+        if not file.filename:
+            raise FilenameNotExistsError
+        if not current_user:
+            raise ValueError("Invalid user account")
+
+        try:
+            upload_file = FileService(db.engine).upload_file(
+                filename=file.filename,
+                content=file.read(),
+                mimetype=file.mimetype,
+                user=current_user,
+            )
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return {
+            "id": upload_file.id,
+            "name": upload_file.name,
+            "size": upload_file.size,
+            "extension": upload_file.extension,
+            "mime_type": upload_file.mime_type,
+            "created_by": upload_file.created_by,
+            "created_at": upload_file.created_at,
+        }, 201

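For the new pipeline-run endpoint, the request body mirrors the parser arguments above. A hypothetical client call (URL, token, and payload values are placeholders; only the field names come from the diff):

import requests

BASE = "https://api.example.com/v1"  # placeholder base URL
headers = {"Authorization": "Bearer dataset-api-key"}  # placeholder token

payload = {
    "inputs": {},                           # user input variables
    "datasource_type": "online_document",
    "datasource_info_list": [],             # datasource info list
    "start_node_id": "datasource-node-1",   # made-up node id
    "is_published": True,                   # published pipeline vs. draft
    "response_mode": "streaming",           # or "blocking"
}
resp = requests.post(
    f"{BASE}/datasets/00000000-0000-0000-0000-000000000000/pipeline/run",
    headers=headers,
    json=payload,
    stream=True,  # in streaming mode the endpoint emits an event stream
)
for line in resp.iter_lines():
    print(line.decode())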
View File

@@ -193,6 +193,47 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
     def decorator(view: Callable[Concatenate[T, P], R]):
         @wraps(view)
         def decorated(*args: P.args, **kwargs: P.kwargs):
+            # get url path dataset_id from positional args or kwargs
+            # Flask passes URL path parameters as positional arguments
+            dataset_id = None
+            # First try to get from kwargs (explicit parameter)
+            dataset_id = kwargs.get("dataset_id")
+            # If not in kwargs, try to extract from positional args
+            if not dataset_id and args:
+                # For class methods: args[0] is self, args[1] is dataset_id (if exists)
+                # Check if first arg is likely a class instance (has __dict__ or __class__)
+                if len(args) > 1 and hasattr(args[0], "__dict__"):
+                    # This is a class method, dataset_id should be in args[1]
+                    potential_id = args[1]
+                    # Validate it's a string-like UUID, not another object
+                    try:
+                        # Try to convert to string and check if it's a valid UUID format
+                        str_id = str(potential_id)
+                        # Basic check: UUIDs are 36 chars with hyphens
+                        if len(str_id) == 36 and str_id.count("-") == 4:
+                            dataset_id = str_id
+                    except:
+                        pass
+                elif len(args) > 0:
+                    # Not a class method, check if args[0] looks like a UUID
+                    potential_id = args[0]
+                    try:
+                        str_id = str(potential_id)
+                        if len(str_id) == 36 and str_id.count("-") == 4:
+                            dataset_id = str_id
+                    except:
+                        pass
+
+            # Validate dataset if dataset_id is provided
+            if dataset_id:
+                dataset_id = str(dataset_id)
+                dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+                if not dataset:
+                    raise NotFound("Dataset not found.")
+                if not dataset.enable_api:
+                    raise Forbidden("Dataset api access is not enabled.")
+
             api_token = validate_and_get_api_token("dataset")
             tenant_account_join = (
                 db.session.query(Tenant, TenantAccountJoin)

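The decorator identifies a dataset_id among the positional arguments with a cheap shape check (36 characters, 4 hyphens). A stricter alternative would parse the candidate with uuid.UUID; an illustrative comparison of the two approaches (standalone code, not part of the diff):

import uuid


def looks_like_uuid(value: object) -> bool:
    # Heuristic used in the diff: length and hyphen count only.
    s = str(value)
    return len(s) == 36 and s.count("-") == 4


def is_valid_uuid(value: object) -> bool:
    # Stricter check: actually parse the candidate.
    try:
        uuid.UUID(str(value))
        return True
    except ValueError:
        return False


assert looks_like_uuid("12345678-1234-1234-1234-123456789012")
# Passes the shape heuristic but fails real parsing:
assert looks_like_uuid("zzzzzzzz-zzzz-zzzz-zzzz-zzzzzzzzzzzz")
assert not is_valid_uuid("zzzzzzzz-zzzz-zzzz-zzzz-zzzzzzzzzzzz")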
View File

@@ -11,6 +11,7 @@ from controllers.common.errors import (
 )
 from controllers.web import web_ns
 from controllers.web.wraps import WebApiResource
+from extensions.ext_database import db
 from fields.file_fields import build_file_model
 from services.file_service import FileService
@@ -68,7 +69,7 @@ class FileApi(WebApiResource):
         source = None

         try:
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                 filename=file.filename,
                 content=file.read(),
                 mimetype=file.mimetype,

View File

@@ -261,6 +261,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
             questions = MessageService.get_suggested_questions_after_answer(
                 app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
             )
+            # questions is a list of strings, not a list of Message objects
+            # so we can directly return it
         except MessageNotExistsError:
             raise NotFound("Message not found")
         except ConversationNotExistsError:

View File

@@ -14,6 +14,7 @@ from controllers.web import web_ns
 from controllers.web.wraps import WebApiResource
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy
+from extensions.ext_database import db
 from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
 from services.file_service import FileService
@@ -119,7 +120,7 @@ class RemoteFileUploadApi(WebApiResource):
         content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content

         try:
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                 filename=file_info.filename,
                 content=content,
                 mimetype=file_info.mimetype,

View File

@@ -21,6 +21,7 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from core.model_runtime.errors.invoke import InvokeError
+from core.workflow.graph_engine.manager import GraphEngineManager
 from libs import helper
 from models.model import App, AppMode, EndUser
 from services.app_generate_service import AppGenerateService
@@ -112,6 +113,11 @@ class WorkflowTaskStopApi(WebApiResource):
         if app_mode != AppMode.WORKFLOW:
             raise NotWorkflowAppError()

-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
+        # Stop using both mechanisms for backward compatibility
+        # Legacy stop flag mechanism (without user check)
+        AppQueueManager.set_stop_flag_no_user_check(task_id)
+        # New graph engine command channel mechanism
+        GraphEngineManager.send_stop_command(task_id)

         return {"result": "success"}

View File

@@ -90,7 +90,9 @@ class BaseAgentRunner(AppRunner):
             tenant_id=tenant_id,
             dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
             retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
-            return_resource=app_config.additional_features.show_retrieve_source,
+            return_resource=(
+                app_config.additional_features.show_retrieve_source if app_config.additional_features else False
+            ),
             invoke_from=application_generate_entity.invoke_from,
             hit_callback=hit_callback,
             user_id=user_id,

View File

@@ -4,8 +4,8 @@ from typing import Any
 from core.app.app_config.entities import ModelConfigEntity
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
-from core.plugin.entities.plugin import ModelProviderID
 from core.provider_manager import ProviderManager
+from models.provider_ids import ModelProviderID


 class ModelConfigManager:

View File

@@ -114,9 +114,9 @@ class VariableEntity(BaseModel):
     hide: bool = False
     max_length: int | None = None
     options: Sequence[str] = Field(default_factory=list)
-    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
-    allowed_file_extensions: Sequence[str] = Field(default_factory=list)
-    allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
+    allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
+    allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
+    allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)

     @field_validator("description", mode="before")
     @classmethod
@@ -129,6 +129,16 @@ class VariableEntity(BaseModel):
         return v or []


+class RagPipelineVariableEntity(VariableEntity):
+    """
+    Rag Pipeline Variable Entity.
+    """
+
+    tooltips: str | None = None
+    placeholder: str | None = None
+    belong_to_node_id: str
+
+
 class ExternalDataVariableEntity(BaseModel):
     """
     External Data Variable Entity.
@@ -288,7 +298,7 @@ class AppConfig(BaseModel):
     tenant_id: str
     app_id: str
     app_mode: AppMode
-    additional_features: AppAdditionalFeatures
+    additional_features: AppAdditionalFeatures | None = None
     variables: list[VariableEntity] = []
     sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None

View File

@@ -1,4 +1,6 @@
-from core.app.app_config.entities import VariableEntity
+import re
+
+from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
 from models.workflow import Workflow
@@ -20,3 +22,48 @@ class WorkflowVariablesConfigManager:
             variables.append(VariableEntity.model_validate(variable))

         return variables
+
+    @classmethod
+    def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
+        """
+        Convert workflow start variables to variables
+
+        :param workflow: workflow instance
+        """
+        variables = []
+
+        # get second step node
+        rag_pipeline_variables = workflow.rag_pipeline_variables
+        if not rag_pipeline_variables:
+            return []
+        variables_map = {item["variable"]: item for item in rag_pipeline_variables}
+
+        # get datasource node data
+        datasource_node_data = None
+        datasource_nodes = workflow.graph_dict.get("nodes", [])
+        for datasource_node in datasource_nodes:
+            if datasource_node.get("id") == start_node_id:
+                datasource_node_data = datasource_node.get("data", {})
+                break
+        if datasource_node_data:
+            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
+            for _, value in datasource_parameters.items():
+                if value.get("value") and isinstance(value.get("value"), str):
+                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
+                    match = re.match(pattern, value["value"])
+                    if match:
+                        full_path = match.group(1)
+                        last_part = full_path.split(".")[-1]
+                        variables_map.pop(last_part, None)
+                if value.get("value") and isinstance(value.get("value"), list):
+                    last_part = value.get("value")[-1]
+                    variables_map.pop(last_part, None)
+
+        all_second_step_variables = list(variables_map.values())
+        for item in all_second_step_variables:
+            if item.get("belong_to_node_id") == start_node_id or item.get("belong_to_node_id") == "shared":
+                variables.append(RagPipelineVariableEntity.model_validate(item))
+
+        return variables

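The pattern in the added method matches Dify-style variable references of the form {{#node_id.variable#}}, so datasource parameters that are already wired to a pipeline variable get dropped from the second-step form. A small standalone demo of the same regex (the input string is made up):

import re

pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"

value = "{{#start_node.file_url#}}"  # made-up variable reference
match = re.match(pattern, value)
if match:
    full_path = match.group(1)            # "start_node.file_url"
    last_part = full_path.split(".")[-1]  # "file_url" -> key removed from variables_map
    print(full_path, last_part)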
View File

@@ -154,7 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         if invoke_from == InvokeFrom.DEBUGGER:
             # always enable retriever resource in debugger mode
-            app_config.additional_features.show_retrieve_source = True
+            app_config.additional_features.show_retrieve_source = True  # type: ignore

         workflow_run_id = str(uuid.uuid4())

         # init application generate entity
@@ -420,7 +420,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             db.session.refresh(conversation)

         # get conversation dialogue count
-        self._dialogue_count = get_thread_messages_length(conversation.id)
+        # NOTE: dialogue_count should not start from 0,
+        # because during the first conversation, dialogue_count should be 1.
+        self._dialogue_count = get_thread_messages_length(conversation.id) + 1

         # init queue manager
         queue_manager = MessageBasedAppQueueManager(
@@ -467,7 +469,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             workflow_execution_repository=workflow_execution_repository,
             workflow_node_execution_repository=workflow_node_execution_repository,
             stream=stream,
-            draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from),
+            draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
         )

         return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

View File

@@ -1,11 +1,11 @@
 import logging
+import time
 from collections.abc import Mapping
 from typing import Any, cast

 from sqlalchemy import select
 from sqlalchemy.orm import Session

-from configs import dify_config
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@@ -23,16 +23,17 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
 from core.moderation.base import ModerationError
 from core.moderation.input_moderation import InputModeration
 from core.variables.variables import VariableUnion
-from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
-from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.entities import GraphRuntimeState, VariablePool
+from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
 from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import VariableLoader
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
+from extensions.ext_redis import redis_client
 from models import Workflow
 from models.enums import UserFrom
 from models.model import App, Conversation, Message, MessageAnnotation
-from models.workflow import ConversationVariable, WorkflowType
+from models.workflow import ConversationVariable

 logger = logging.getLogger(__name__)
@@ -78,23 +79,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         if not app_record:
             raise ValueError("App not found")

-        workflow_callbacks: list[WorkflowCallback] = []
-        if dify_config.DEBUG:
-            workflow_callbacks.append(WorkflowLoggingCallback())
-
-        if self.application_generate_entity.single_iteration_run:
-            # if only single iteration run is requested
-            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
+        if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+            # Handle single iteration or single loop run
+            graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
                 workflow=self._workflow,
-                node_id=self.application_generate_entity.single_iteration_run.node_id,
-                user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
-            )
-        elif self.application_generate_entity.single_loop_run:
-            # if only single loop run is requested
-            graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
-                workflow=self._workflow,
-                node_id=self.application_generate_entity.single_loop_run.node_id,
-                user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
+                single_iteration_run=self.application_generate_entity.single_iteration_run,
+                single_loop_run=self.application_generate_entity.single_loop_run,
             )
         else:
             inputs = self.application_generate_entity.inputs
@@ -146,16 +136,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             )

         # init graph
-        graph = self._init_graph(graph_config=self._workflow.graph_dict)
+        graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
+        graph = self._init_graph(
+            graph_config=self._workflow.graph_dict,
+            graph_runtime_state=graph_runtime_state,
+            workflow_id=self._workflow.id,
+            tenant_id=self._workflow.tenant_id,
+            user_id=self.application_generate_entity.user_id,
+        )

         db.session.close()

         # RUN WORKFLOW
+        # Create Redis command channel for this workflow execution
+        task_id = self.application_generate_entity.task_id
+        channel_key = f"workflow:{task_id}:commands"
+        command_channel = RedisChannel(redis_client, channel_key)
+
         workflow_entry = WorkflowEntry(
             tenant_id=self._workflow.tenant_id,
             app_id=self._workflow.app_id,
             workflow_id=self._workflow.id,
-            workflow_type=WorkflowType.value_of(self._workflow.type),
             graph=graph,
             graph_config=self._workflow.graph_dict,
             user_id=self.application_generate_entity.user_id,
@@ -167,11 +168,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             invoke_from=self.application_generate_entity.invoke_from,
             call_depth=self.application_generate_entity.call_depth,
             variable_pool=variable_pool,
+            graph_runtime_state=graph_runtime_state,
+            command_channel=command_channel,
         )

-        generator = workflow_entry.run(
-            callbacks=workflow_callbacks,
-        )
+        generator = workflow_entry.run()

         for event in generator:
             self._handle_event(workflow_entry, event)

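The runner publishes and consumes workflow commands over a per-task Redis channel, so the key convention must match what GraphEngineManager derives when it sends a stop command. A sketch of that symmetry — RedisChannel and redis_client are taken from the diff, but the claim that the manager rebuilds the same key internally is an assumption:

# Sketch: sender and receiver must derive the same key from the task id.
task_id = "task-123"  # placeholder
channel_key = f"workflow:{task_id}:commands"

# Receiver side (as in AdvancedChatAppRunner): the engine polls this channel.
command_channel = RedisChannel(redis_client, channel_key)

# Sender side (as in the stop APIs): presumably builds the same key internally.
GraphEngineManager.send_stop_command(task_id)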
View File

@@ -31,14 +31,9 @@ from core.app.entities.queue_entities import (
     QueueMessageReplaceEvent,
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
-    QueueNodeInIterationFailedEvent,
-    QueueNodeInLoopFailedEvent,
     QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
-    QueueParallelBranchRunFailedEvent,
-    QueueParallelBranchRunStartedEvent,
-    QueueParallelBranchRunSucceededEvent,
     QueuePingEvent,
     QueueRetrieverResourcesEvent,
     QueueStopEvent,
@@ -65,8 +60,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
 from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
-from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.entities import GraphRuntimeState
+from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
 from core.workflow.nodes import NodeType
 from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
@@ -387,9 +382,7 @@ class AdvancedChatAppGenerateTaskPipeline:
     def _handle_node_failed_events(
         self,
-        event: Union[
-            QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
-        ],
+        event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle various node failure events."""
@@ -434,32 +427,6 @@ class AdvancedChatAppGenerateTaskPipeline:
             answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
         )

-    def _handle_parallel_branch_started_event(
-        self, event: QueueParallelBranchRunStartedEvent, **kwargs
-    ) -> Generator[StreamResponse, None, None]:
-        """Handle parallel branch started events."""
-        self._ensure_workflow_initialized()
-        parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
-            task_id=self._application_generate_entity.task_id,
-            workflow_execution_id=self._workflow_run_id,
-            event=event,
-        )
-        yield parallel_start_resp
-
-    def _handle_parallel_branch_finished_events(
-        self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
-    ) -> Generator[StreamResponse, None, None]:
-        """Handle parallel branch finished events."""
-        self._ensure_workflow_initialized()
-        parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
-            task_id=self._application_generate_entity.task_id,
-            workflow_execution_id=self._workflow_run_id,
-            event=event,
-        )
-        yield parallel_finish_resp
-
     def _handle_iteration_start_event(
         self, event: QueueIterationStartEvent, **kwargs
     ) -> Generator[StreamResponse, None, None]:
@@ -751,8 +718,6 @@ class AdvancedChatAppGenerateTaskPipeline:
             QueueNodeRetryEvent: self._handle_node_retry_event,
             QueueNodeStartedEvent: self._handle_node_started_event,
             QueueNodeSucceededEvent: self._handle_node_succeeded_event,
-            # Parallel branch events
-            QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
             # Iteration events
             QueueIterationStartEvent: self._handle_iteration_start_event,
             QueueIterationNextEvent: self._handle_iteration_next_event,
@@ -800,8 +765,6 @@ class AdvancedChatAppGenerateTaskPipeline:
             event,
             (
                 QueueNodeFailedEvent,
-                QueueNodeInIterationFailedEvent,
-                QueueNodeInLoopFailedEvent,
                 QueueNodeExceptionEvent,
             ),
         ):
@@ -814,17 +777,6 @@ class AdvancedChatAppGenerateTaskPipeline:
             )
             return

-        # Handle parallel branch finished events with isinstance check
-        if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
-            yield from self._handle_parallel_branch_finished_events(
-                event,
-                graph_runtime_state=graph_runtime_state,
-                tts_publisher=tts_publisher,
-                trace_manager=trace_manager,
-                queue_message=queue_message,
-            )
-            return
-
         # For unhandled events, we continue (original behavior)
         return
@@ -848,11 +800,6 @@ class AdvancedChatAppGenerateTaskPipeline:
                     graph_runtime_state = event.graph_runtime_state
                     yield from self._handle_workflow_started_event(event)

-                case QueueTextChunkEvent():
-                    yield from self._handle_text_chunk_event(
-                        event, tts_publisher=tts_publisher, queue_message=queue_message
-                    )
-
                 case QueueErrorEvent():
                     yield from self._handle_error_event(event)
                     break

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
 from core.app.app_config.entities import VariableEntityType
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.file import File, FileUploadConfig
-from core.workflow.nodes.enums import NodeType
+from core.workflow.enums import NodeType
 from core.workflow.repositories.draft_variable_repository import (
     DraftVariableSaver,
     DraftVariableSaverFactory,
@@ -14,6 +14,7 @@ from core.workflow.repositories.draft_variable_repository import (
 )
 from factories import file_factory
 from libs.orjson import orjson_dumps
+from models import Account, EndUser
 from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl

 if TYPE_CHECKING:
@@ -44,9 +45,9 @@ class BaseAppGenerator:
                     mapping=v,
                     tenant_id=tenant_id,
                     config=FileUploadConfig(
-                        allowed_file_types=entity_dictionary[k].allowed_file_types,
-                        allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
-                        allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                        allowed_file_types=entity_dictionary[k].allowed_file_types or [],
+                        allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
+                        allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
                     ),
                     strict_type_validation=strict_type_validation,
                 )
@@ -59,9 +60,9 @@ class BaseAppGenerator:
                     mappings=v,
                     tenant_id=tenant_id,
                     config=FileUploadConfig(
-                        allowed_file_types=entity_dictionary[k].allowed_file_types,
-                        allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
-                        allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                        allowed_file_types=entity_dictionary[k].allowed_file_types or [],
+                        allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
+                        allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
                    ),
                )
                for k, v in user_inputs.items()
@@ -182,8 +183,9 @@ class BaseAppGenerator:
     @final
     @staticmethod
-    def _get_draft_var_saver_factory(invoke_from: InvokeFrom) -> DraftVariableSaverFactory:
+    def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
         if invoke_from == InvokeFrom.DEBUGGER:
+            assert isinstance(account, Account)

             def draft_var_saver_factory(
                 session: Session,
@@ -200,6 +202,7 @@ class BaseAppGenerator:
                     node_type=node_type,
                     node_execution_id=node_execution_id,
                     enclosing_node_id=enclosing_node_id,
+                    user=account,
                 )
         else:
View File
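Because the allowed_file_* fields on VariableEntity are now Optional, every consumer coalesces None to an empty list before building FileUploadConfig. A compact illustration of why the `or []` guard matters — the model below is a stand-in for core.file.FileUploadConfig, not the real class:

from collections.abc import Sequence

from pydantic import BaseModel


class FileUploadConfig(BaseModel):
    # stand-in for core.file.FileUploadConfig; the real model has more fields
    allowed_file_types: Sequence[str]
    allowed_file_extensions: Sequence[str]


allowed_types = None  # an Optional field deserialized from a legacy app config
config = FileUploadConfig(
    allowed_file_types=allowed_types or [],  # passing None would fail validation
    allowed_file_extensions=[".pdf", ".txt"],
)
print(config.allowed_file_types)  # []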

@@ -127,6 +127,21 @@ class AppQueueManager:
         stopped_cache_key = cls._generate_stopped_cache_key(task_id)
         redis_client.setex(stopped_cache_key, 600, 1)

+    @classmethod
+    def set_stop_flag_no_user_check(cls, task_id: str) -> None:
+        """
+        Set task stop flag without user permission check.
+        This method allows stopping workflows without user context.
+
+        :param task_id: The task ID to stop
+        :return:
+        """
+        if not task_id:
+            return
+
+        stopped_cache_key = cls._generate_stopped_cache_key(task_id)
+        redis_client.setex(stopped_cache_key, 600, 1)
+
     def _is_stopped(self) -> bool:
         """
         Check if task is stopped

View File

@@ -164,7 +164,9 @@ class ChatAppRunner(AppRunner):
                 config=app_config.dataset,
                 query=query,
                 invoke_from=application_generate_entity.invoke_from,
-                show_retrieve_source=app_config.additional_features.show_retrieve_source,
+                show_retrieve_source=(
+                    app_config.additional_features.show_retrieve_source if app_config.additional_features else False
+                ),
                 hit_callback=hit_callback,
                 memory=memory,
                 message_id=message.id,

View File

@@ -1,7 +1,7 @@
 import time
 from collections.abc import Mapping, Sequence
 from datetime import UTC, datetime
-from typing import Any, Union, cast
+from typing import Any, Union

 from sqlalchemy.orm import Session
@@ -16,14 +16,9 @@ from core.app.entities.queue_entities import (
     QueueLoopStartEvent,
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
-    QueueNodeInIterationFailedEvent,
-    QueueNodeInLoopFailedEvent,
     QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
-    QueueParallelBranchRunFailedEvent,
-    QueueParallelBranchRunStartedEvent,
-    QueueParallelBranchRunSucceededEvent,
 )
 from core.app.entities.task_entities import (
     AgentLogStreamResponse,
@@ -36,24 +31,23 @@ from core.app.entities.task_entities import (
     NodeFinishStreamResponse,
     NodeRetryStreamResponse,
     NodeStartStreamResponse,
-    ParallelBranchFinishedStreamResponse,
-    ParallelBranchStartStreamResponse,
     WorkflowFinishStreamResponse,
     WorkflowStartStreamResponse,
 )
 from core.file import FILE_MODEL_IDENTITY, File
+from core.plugin.impl.datasource import PluginDatasourceManager
+from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.tool_manager import ToolManager
 from core.variables.segments import ArrayFileSegment, FileSegment, Segment
-from core.workflow.entities.workflow_execution import WorkflowExecution
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
-from core.workflow.nodes import NodeType
-from core.workflow.nodes.tool.entities import ToolNodeData
+from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
+from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
 from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from libs.datetime_utils import naive_utc_now
 from models import (
     Account,
     EndUser,
 )
+from services.variable_truncator import VariableTruncator


 class WorkflowResponseConverter:
@@ -65,6 +59,7 @@ class WorkflowResponseConverter:
     ):
         self._application_generate_entity = application_generate_entity
         self._user = user
+        self._truncator = VariableTruncator.default()

     def workflow_start_to_stream_response(
         self,
@@ -156,7 +151,8 @@ class WorkflowResponseConverter:
                 title=workflow_node_execution.title,
                 index=workflow_node_execution.index,
                 predecessor_node_id=workflow_node_execution.predecessor_node_id,
-                inputs=workflow_node_execution.inputs,
+                inputs=workflow_node_execution.get_response_inputs(),
+                inputs_truncated=workflow_node_execution.inputs_truncated,
                 created_at=int(workflow_node_execution.created_at.timestamp()),
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
@@ -171,11 +167,19 @@ class WorkflowResponseConverter:
         # extras logic
         if event.node_type == NodeType.TOOL:
-            node_data = cast(ToolNodeData, event.node_data)
             response.data.extras["icon"] = ToolManager.get_tool_icon(
                 tenant_id=self._application_generate_entity.app_config.tenant_id,
-                provider_type=node_data.provider_type,
-                provider_id=node_data.provider_id,
+                provider_type=ToolProviderType(event.provider_type),
+                provider_id=event.provider_id,
+            )
+        elif event.node_type == NodeType.DATASOURCE:
+            manager = PluginDatasourceManager()
+            provider_entity = manager.fetch_datasource_provider(
+                self._application_generate_entity.app_config.tenant_id,
+                event.provider_id,
+            )
+            response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
+                self._application_generate_entity.app_config.tenant_id
             )

         return response
@@ -183,11 +187,7 @@ class WorkflowResponseConverter:
     def workflow_node_finish_to_stream_response(
         self,
         *,
-        event: QueueNodeSucceededEvent
-        | QueueNodeFailedEvent
-        | QueueNodeInIterationFailedEvent
-        | QueueNodeInLoopFailedEvent
-        | QueueNodeExceptionEvent,
+        event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
     ) -> NodeFinishStreamResponse | None:
@@ -210,9 +210,12 @@ class WorkflowResponseConverter:
                 index=workflow_node_execution.index,
                 title=workflow_node_execution.title,
                 predecessor_node_id=workflow_node_execution.predecessor_node_id,
-                inputs=workflow_node_execution.inputs,
-                process_data=workflow_node_execution.process_data,
-                outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
+                inputs=workflow_node_execution.get_response_inputs(),
+                inputs_truncated=workflow_node_execution.inputs_truncated,
+                process_data=workflow_node_execution.get_response_process_data(),
+                process_data_truncated=workflow_node_execution.process_data_truncated,
+                outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
+                outputs_truncated=workflow_node_execution.outputs_truncated,
                 status=workflow_node_execution.status,
                 error=workflow_node_execution.error,
                 elapsed_time=workflow_node_execution.elapsed_time,
@@ -221,9 +224,6 @@ class WorkflowResponseConverter:
                 finished_at=int(workflow_node_execution.finished_at.timestamp()),
                 files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
                 parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
-                parent_parallel_id=event.parent_parallel_id,
-                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                 iteration_id=event.in_iteration_id,
                 loop_id=event.in_loop_id,
             ),
@@ -255,9 +255,12 @@ class WorkflowResponseConverter:
                 index=workflow_node_execution.index,
                 title=workflow_node_execution.title,
                 predecessor_node_id=workflow_node_execution.predecessor_node_id,
-                inputs=workflow_node_execution.inputs,
-                process_data=workflow_node_execution.process_data,
-                outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
+                inputs=workflow_node_execution.get_response_inputs(),
+                inputs_truncated=workflow_node_execution.inputs_truncated,
+                process_data=workflow_node_execution.get_response_process_data(),
+                process_data_truncated=workflow_node_execution.process_data_truncated,
+                outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
+                outputs_truncated=workflow_node_execution.outputs_truncated,
                 status=workflow_node_execution.status,
                 error=workflow_node_execution.error,
                 elapsed_time=workflow_node_execution.elapsed_time,
@@ -275,50 +278,6 @@ class WorkflowResponseConverter:
             ),
         )

-    def workflow_parallel_branch_start_to_stream_response(
-        self,
-        *,
-        task_id: str,
-        workflow_execution_id: str,
-        event: QueueParallelBranchRunStartedEvent,
-    ) -> ParallelBranchStartStreamResponse:
-        return ParallelBranchStartStreamResponse(
-            task_id=task_id,
-            workflow_run_id=workflow_execution_id,
-            data=ParallelBranchStartStreamResponse.Data(
-                parallel_id=event.parallel_id,
-                parallel_branch_id=event.parallel_start_node_id,
-                parent_parallel_id=event.parent_parallel_id,
-                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                iteration_id=event.in_iteration_id,
-                loop_id=event.in_loop_id,
-                created_at=int(time.time()),
-            ),
-        )
-
-    def workflow_parallel_branch_finished_to_stream_response(
-        self,
-        *,
-        task_id: str,
-        workflow_execution_id: str,
-        event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
-    ) -> ParallelBranchFinishedStreamResponse:
-        return ParallelBranchFinishedStreamResponse(
-            task_id=task_id,
-            workflow_run_id=workflow_execution_id,
-            data=ParallelBranchFinishedStreamResponse.Data(
-                parallel_id=event.parallel_id,
-                parallel_branch_id=event.parallel_start_node_id,
-                parent_parallel_id=event.parent_parallel_id,
-                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                iteration_id=event.in_iteration_id,
-                loop_id=event.in_loop_id,
-                status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
-                error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
-                created_at=int(time.time()),
-            ),
-        )
-
     def workflow_iteration_start_to_stream_response(
         self,
         *,
@@ -326,6 +285,7 @@ class WorkflowResponseConverter:
         workflow_execution_id: str,
         event: QueueIterationStartEvent,
     ) -> IterationNodeStartStreamResponse:
+        new_inputs, truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
         return IterationNodeStartStreamResponse(
             task_id=task_id,
             workflow_run_id=workflow_execution_id,
@@ -333,13 +293,12 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
+                title=event.node_title,
                 created_at=int(time.time()),
                 extras={},
-                inputs=event.inputs or {},
+                inputs=new_inputs,
+                inputs_truncated=truncated,
                 metadata=event.metadata or {},
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
             ),
         )
@@ -357,15 +316,10 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
+                title=event.node_title,
                 index=event.index,
-                pre_iteration_output=event.output,
                 created_at=int(time.time()),
                 extras={},
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
-                parallel_mode_run_id=event.parallel_mode_run_id,
-                duration=event.duration,
             ),
         )
@@ -377,6 +331,11 @@ class WorkflowResponseConverter:
         event: QueueIterationCompletedEvent,
     ) -> IterationNodeCompletedStreamResponse:
         json_converter = WorkflowRuntimeTypeConverter()
+        new_inputs, inputs_truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
+        new_outputs, outputs_truncated = self._truncator.truncate_variable_mapping(
+            json_converter.to_json_encodable(event.outputs) or {}
+        )
         return IterationNodeCompletedStreamResponse(
             task_id=task_id,
             workflow_run_id=workflow_execution_id,
@@ -384,28 +343,29 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
-                outputs=json_converter.to_json_encodable(event.outputs),
+                title=event.node_title,
+                outputs=new_outputs,
+                outputs_truncated=outputs_truncated,
                 created_at=int(time.time()),
                 extras={},
-                inputs=event.inputs or {},
+                inputs=new_inputs,
+                inputs_truncated=inputs_truncated,
                 status=WorkflowNodeExecutionStatus.SUCCEEDED
                 if event.error is None
                 else WorkflowNodeExecutionStatus.FAILED,
                 error=None,
                 elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
-                total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
+                total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
                 execution_metadata=event.metadata,
                 finished_at=int(time.time()),
                 steps=event.steps,
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
             ),
         )

     def workflow_loop_start_to_stream_response(
         self, *, task_id: str, workflow_execution_id: str, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse: ) -> LoopNodeStartStreamResponse:
new_inputs, truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
return LoopNodeStartStreamResponse( return LoopNodeStartStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_execution_id, workflow_run_id=workflow_execution_id,
@@ -413,10 +373,11 @@ class WorkflowResponseConverter:
id=event.node_id, id=event.node_id,
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type.value, node_type=event.node_type.value,
title=event.node_data.title, title=event.node_title,
created_at=int(time.time()), created_at=int(time.time()),
extras={}, extras={},
inputs=event.inputs or {}, inputs=new_inputs,
inputs_truncated=truncated,
metadata=event.metadata or {}, metadata=event.metadata or {},
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
@@ -437,15 +398,16 @@ class WorkflowResponseConverter:
id=event.node_id, id=event.node_id,
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type.value, node_type=event.node_type.value,
title=event.node_data.title, title=event.node_title,
index=event.index, index=event.index,
pre_loop_output=event.output, # The `pre_loop_output` field is not utilized by the frontend.
# Previously, it was assigned the value of `event.output`.
pre_loop_output={},
created_at=int(time.time()), created_at=int(time.time()),
extras={}, extras={},
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id, parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
), ),
) )
@@ -456,6 +418,11 @@ class WorkflowResponseConverter:
workflow_execution_id: str, workflow_execution_id: str,
event: QueueLoopCompletedEvent, event: QueueLoopCompletedEvent,
) -> LoopNodeCompletedStreamResponse: ) -> LoopNodeCompletedStreamResponse:
json_converter = WorkflowRuntimeTypeConverter()
new_inputs, inputs_truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
new_outputs, outputs_truncated = self._truncator.truncate_variable_mapping(
json_converter.to_json_encodable(event.outputs) or {}
)
return LoopNodeCompletedStreamResponse( return LoopNodeCompletedStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_execution_id, workflow_run_id=workflow_execution_id,
@@ -463,17 +430,19 @@ class WorkflowResponseConverter:
id=event.node_id, id=event.node_id,
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type.value, node_type=event.node_type.value,
title=event.node_data.title, title=event.node_title,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs), outputs=new_outputs,
outputs_truncated=outputs_truncated,
created_at=int(time.time()), created_at=int(time.time()),
extras={}, extras={},
inputs=event.inputs or {}, inputs=new_inputs,
inputs_truncated=inputs_truncated,
status=WorkflowNodeExecutionStatus.SUCCEEDED status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None if event.error is None
else WorkflowNodeExecutionStatus.FAILED, else WorkflowNodeExecutionStatus.FAILED,
error=None, error=None,
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(), elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
execution_metadata=event.metadata, execution_metadata=event.metadata,
finished_at=int(time.time()), finished_at=int(time.time()),
steps=event.steps, steps=event.steps,
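Taken together, the converter changes above route every `inputs` / `process_data` / `outputs` payload through a truncator before it is streamed, and surface a companion `*_truncated` flag so clients can tell when a value was cut short and fetch the full record separately. A minimal sketch of the pattern, assuming a size-budgeted truncator (the `VariableTruncator` class and `max_bytes` budget here are illustrative, not the PR's actual implementation):

import json
from collections.abc import Mapping
from typing import Any


class VariableTruncator:
    # Illustrative stand-in: cap each value's serialized size and report
    # whether anything was shortened.
    def __init__(self, max_bytes: int = 1024) -> None:
        self.max_bytes = max_bytes

    def truncate_variable_mapping(self, mapping: Mapping[str, Any]) -> tuple[dict[str, Any], bool]:
        truncated = False
        result: dict[str, Any] = {}
        for key, value in mapping.items():
            encoded = json.dumps(value, default=str)
            if len(encoded) > self.max_bytes:
                # Keep only a prefix; the flag tells clients the full value is retrievable elsewhere.
                result[key] = encoded[: self.max_bytes]
                truncated = True
            else:
                result[key] = value
        return result, truncated

Usage then mirrors the converter methods: `new_inputs, truncated = self._truncator.truncate_variable_mapping(event.inputs or {})`, with `inputs=new_inputs, inputs_truncated=truncated` in the response. The `(lambda x: x if isinstance(x, int) else 0)(...)` wrapper on `total_tokens` is similarly defensive: metadata values are untyped, and it normalizes anything non-integer to 0 rather than propagating it into the response.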

View File

@@ -124,7 +124,9 @@ class CompletionAppRunner(AppRunner):
             config=dataset_config,
             query=query or "",
             invoke_from=application_generate_entity.invoke_from,
-            show_retrieve_source=app_config.additional_features.show_retrieve_source,
+            show_retrieve_source=app_config.additional_features.show_retrieve_source
+            if app_config.additional_features
+            else False,
             hit_callback=hit_callback,
             message_id=message.id,
             inputs=inputs,
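This hunk is a null-guard: `additional_features` can be absent, and dereferencing `show_retrieve_source` on `None` raises an `AttributeError`. Isolated, the guard looks like the sketch below (the function and parameter names are assumed for illustration):

def resolve_show_retrieve_source(additional_features) -> bool:
    # Fall back to False when the feature block is missing instead of raising.
    return additional_features.show_retrieve_source if additional_features else False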

View File

@@ -0,0 +1,95 @@
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
:return:
"""
return cls.convert_blocking_full_response(blocking_response)
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
else:
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
else:
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk
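For orientation, a hedged sketch of how a transport layer might consume this converter's output: each yielded chunk is either the literal string "ping" (a keep-alive for PingStreamResponse events) or a dict describing one event, which an SSE endpoint would serialize. The `to_sse` helper below is illustrative and not part of this PR:

import json
from collections.abc import Iterable, Iterator


def to_sse(chunks: Iterable[dict | str]) -> Iterator[str]:
    for chunk in chunks:
        if chunk == "ping":
            # Keep-alive marker; carries no payload.
            yield "event: ping\n\n"
        else:
            yield f"data: {json.dumps(chunk)}\n\n"

The simple variant differs only in that node start/finish events are reduced via `to_ignore_detail_dict()`, stripping bulky detail fields before serialization.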

View File

@@ -0,0 +1,66 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.dataset import Pipeline
from models.model import AppMode
from models.workflow import Workflow
class PipelineConfig(WorkflowUIBasedAppConfig):
"""
Pipeline Config Entity.
"""
rag_pipeline_variables: list[RagPipelineVariableEntity] = []
class PipelineConfigManager(BaseAppConfigManager):
@classmethod
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow, start_node_id: str) -> PipelineConfig:
pipeline_config = PipelineConfig(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
app_mode=AppMode.RAG_PIPELINE,
workflow_id=workflow.id,
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
workflow=workflow, start_node_id=start_node_id
),
)
return pipeline_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
"""
Validate for pipeline config
:param tenant_id: tenant id
:param config: app model config args
:param only_structure_validate: only validate the structure of the config
"""
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config
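Note that `config_validate` is allow-list based: only keys claimed by the feature managers survive the final filtering step. A small self-contained illustration with made-up keys:

# Keys not registered by a feature manager are silently dropped.
config = {"file_upload": {"enabled": True}, "text_to_speech": {"enabled": False}, "unknown_key": 1}
related_config_keys = ["file_upload", "text_to_speech"]
filtered_config = {key: config.get(key) for key in related_config_keys}
assert "unknown_key" not in filtered_config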

View File

@@ -0,0 +1,856 @@
import contextvars
import datetime
import json
import logging
import secrets
import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, cast, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
)
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories.factory import DifyCoreRepositoryFactory
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.datasource_provider_service import DatasourceProviderService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
class PipelineGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: str | None = None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# Add null check for dataset
with Session(db.engine, expire_on_commit=False) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
)
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
)
documents: list[Document] = []
if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
from services.dataset_service import DocumentService
for datasource_info in datasource_info_list:
position = DocumentService.get_documents_position(dataset.id)
document = self._build_document(
tenant_id=pipeline.tenant_id,
dataset_id=dataset.id,
built_in_field_enabled=dataset.built_in_field_enabled,
datasource_type=datasource_type,
datasource_info=datasource_info,
created_from="rag-pipeline",
position=position,
account=user,
batch=batch,
document_form=dataset.chunk_structure,
)
db.session.add(document)
documents.append(document)
db.session.commit()
# run in child thread
rag_pipeline_invoke_entities = []
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = args.get("original_document_id") or None
if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
document_id = document_id or documents[i].id
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document_id,
datasource_type=datasource_type,
datasource_info=json.dumps(datasource_info),
datasource_node_id=start_node_id,
input_data=inputs,
pipeline_id=pipeline.id,
created_by=user.id,
)
db.session.add(document_pipeline_execution_log)
db.session.commit()
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=datasource_type,
datasource_info=datasource_info,
dataset_id=dataset.id,
original_document_id=args.get("original_document_id"),
start_node_id=start_node_id,
batch=batch,
document_id=document_id,
inputs=self._prepare_user_inputs(
user_inputs=inputs,
variables=pipeline_config.rag_pipeline_variables,
tenant_id=pipeline.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),
files=[],
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
call_depth=call_depth,
workflow_execution_id=workflow_run_id,
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
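# Debugger runs and retries execute inline in this process and stream back
# directly; published runs fall through to the batching branch below and are
# dispatched to Celery workers instead.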
if invoke_from == InvokeFrom.DEBUGGER or is_retry:
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
context=contextvars.copy_context(),
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
else:
rag_pipeline_invoke_entities.append(
RagPipelineInvokeEntity(
pipeline_id=pipeline.id,
user_id=user.id,
tenant_id=pipeline.tenant_id,
workflow_id=workflow.id,
streaming=streaming,
workflow_execution_id=workflow_run_id,
workflow_thread_pool_id=workflow_thread_pool_id,
application_generate_entity=application_generate_entity.model_dump(),
)
)
if rag_pipeline_invoke_entities:
# store the rag_pipeline_invoke_entities to object storage
text = [item.model_dump() for item in rag_pipeline_invoke_entities]
name = "rag_pipeline_invoke_entities.json"
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
features = FeatureService.get_features(dataset.tenant_id)
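# Sandbox-plan tenants are throttled to one pipeline run at a time: a Redis flag
# (1-hour TTL) marks an active run, and additional runs are parked in a per-tenant
# Redis list, to be picked up later. Paid plans go straight to the priority queue.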
if features.billing.subscription.plan == "sandbox":
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
if redis_client.get(tenant_pipeline_task_key):
# Add to waiting queue using List operations (lpush)
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
else:
# Set flag and execute task
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
else:
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
# return batch, dataset, documents
return {
"batch": batch,
"dataset": PipelineDataset(
id=dataset.id,
name=dataset.name,
description=dataset.description,
chunk_structure=dataset.chunk_structure,
).model_dump(),
"documents": [
PipelineDocument(
id=document.id,
position=document.position,
data_source_type=document.data_source_type,
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump()
for document in documents
],
}
def _generate(
self,
*,
flask_app: Flask,
context: contextvars.Context,
pipeline: Pipeline,
workflow_id: str,
user: Union[Account, EndUser],
application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
workflow_thread_pool_id: str | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
:param pipeline: Pipeline
:param workflow_id: workflow ID
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
with preserve_flask_contexts(flask_app, context_vars=context):
# init queue manager
workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
queue_manager = PipelineQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=AppMode.RAG_PIPELINE,
)
context = contextvars.copy_context()
# new thread
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"queue_manager": queue_manager,
"application_generate_entity": application_generate_entity,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
},
)
worker_thread.start()
draft_var_saver_factory = self._get_draft_var_saver_factory(
invoke_from,
user,
)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
draft_var_saver_factory=draft_var_saver_factory,
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def single_iteration_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param pipeline: Pipeline
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")
# init application generate entity - use RagPipelineGenerateEntity instead
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
dataset_id=dataset.id,
batch=args.get("batch", ""),
document_id=args.get("document_id"),
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
workflow_execution_id=str(uuid.uuid4()),
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)
def single_loop_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param pipeline: Pipeline
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)
# init application generate entity
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
batch=args.get("batch", ""),
document_id=args.get("document_id"),
dataset_id=dataset.id,
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
try:
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
Workflow.app_id == application_generate_entity.app_config.app_id,
Workflow.id == application_generate_entity.app_config.workflow_id,
)
)
if workflow is None:
raise ValueError("Workflow not found")
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}
if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(
select(EndUser).where(EndUser.id == application_generate_entity.user_id)
)
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id
# workflow app
runner = PipelineRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
)
runner.run()
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def _handle_response(
self,
application_generate_entity: RagPipelineGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: repository for workflow node execution
:return:
"""
# init generate task pipeline
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
workflow_execution_repository=workflow_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
)
try:
return generate_task_pipeline.process()
except ValueError as e:
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(
"Fails to process generate task pipeline, task_id: %r",
application_generate_entity.task_id,
)
raise e
def _build_document(
self,
tenant_id: str,
dataset_id: str,
built_in_field_enabled: bool,
datasource_type: str,
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
account: Union[Account, EndUser],
batch: str,
document_form: str,
):
if datasource_type == "local_file":
name = datasource_info.get("name", "untitled")
elif datasource_type == "online_document":
name = datasource_info.get("page", {}).get("page_name", "untitled")
elif datasource_type == "website_crawl":
name = datasource_info.get("title", "untitled")
elif datasource_type == "online_drive":
name = datasource_info.get("name", "untitled")
else:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=position,
data_source_type=datasource_type,
data_source_info=json.dumps(datasource_info),
batch=batch,
name=name,
created_from=created_from,
created_by=account.id,
doc_form=document_form,
)
doc_metadata = {}
if built_in_field_enabled:
doc_metadata = {
BuiltInField.document_name: name,
BuiltInField.uploader: account.name,
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.source: datasource_type,
}
if doc_metadata:
document.doc_metadata = doc_metadata
return document
def _format_datasource_info_list(
self,
datasource_type: str,
datasource_info_list: list[Mapping[str, Any]],
pipeline: Pipeline,
workflow: Workflow,
start_node_id: str,
user: Union[Account, EndUser],
) -> list[Mapping[str, Any]]:
"""
Format datasource info list.
"""
if datasource_type == "online_drive":
all_files: list[Mapping[str, Any]] = []
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == start_node_id:
datasource_node_data = datasource_node.get("data", {})
break
if not datasource_node_data:
raise ValueError("Datasource node data not found")
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
datasource_name=datasource_node_data.get("datasource_name"),
tenant_id=pipeline.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get("provider_name"),
plugin_id=datasource_node_data.get("plugin_id"),
credential_id=datasource_node_data.get("credential_id"),
)
if credentials:
datasource_runtime.runtime.credentials = credentials
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
for datasource_info in datasource_info_list:
if datasource_info.get("id") and datasource_info.get("type") == "folder":
# get all files in the folder
self._get_files_in_folder(
datasource_runtime,
datasource_info.get("id", ""),
datasource_info.get("bucket", None),
user.id,
all_files,
datasource_info,
None,
)
else:
all_files.append(
{
"id": datasource_info.get("id", ""),
"name": datasource_info.get("name", "untitled"),
"bucket": datasource_info.get("bucket", None),
}
)
return all_files
else:
return datasource_info_list
def _get_files_in_folder(
self,
datasource_runtime: OnlineDriveDatasourcePlugin,
prefix: str,
bucket: str | None,
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
next_page_parameters: dict | None = None,
):
"""
Get files in a folder.
"""
result_generator = datasource_runtime.online_drive_browse_files(
user_id=user_id,
request=OnlineDriveBrowseFilesRequest(
bucket=bucket,
prefix=prefix,
max_keys=20,
next_page_parameters=next_page_parameters,
),
provider_type=datasource_runtime.datasource_provider_type(),
)
is_truncated = False
for result in result_generator:
for files in result.result:
for file in files.files:
if file.type == "folder":
self._get_files_in_folder(
datasource_runtime,
file.id,
bucket,
user_id,
all_files,
datasource_info,
None,
)
else:
all_files.append(
{
"id": file.id,
"name": file.name,
"bucket": bucket,
}
)
is_truncated = files.is_truncated
next_page_parameters = files.next_page_parameters
if is_truncated:
self._get_files_in_folder(
datasource_runtime, prefix, bucket, user_id, all_files, datasource_info, next_page_parameters
)

View File

@@ -0,0 +1,45 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueErrorEvent,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
class PipelineQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
self._q.put(message)
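# Terminal events: once any of these is published, stop consuming from the queue.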
if isinstance(
event,
QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedError()

View File

@@ -0,0 +1,263 @@
import logging
import time
from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Document, Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class PipelineRunner(WorkflowBasedAppRunner):
"""
Pipeline Application Runner
"""
def __init__(
self,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
def run(self) -> None:
"""
Run application
"""
app_config = self.application_generate_entity.app_config
app_config = cast(PipelineConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
db.session.close()
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = SystemVariable(
files=files,
user_id=user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
document_id=self.application_generate_entity.document_id,
original_document_id=self.application_generate_entity.original_document_id,
batch=self.application_generate_entity.batch,
dataset_id=self.application_generate_entity.dataset_id,
datasource_type=self.application_generate_entity.datasource_type,
datasource_info=self.application_generate_entity.datasource_info,
invoke_from=self.application_generate_entity.invoke_from.value,
)
rag_pipeline_variables = []
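# Expose to the variable pool only those variables scoped to the selected
# datasource start node (or the "shared" scope) that the caller actually provided.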
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if (
rag_pipeline_variable.belong_to_node_id
in (self.application_generate_entity.start_node_id, "shared")
) and rag_pipeline_variable.variable in inputs:
rag_pipeline_variables.append(
RAGPipelineVariableInput(
variable=rag_pipeline_variable,
value=inputs[rag_pipeline_variable.variable],
)
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
rag_pipeline_variables=rag_pipeline_variables,
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_rag_pipeline_graph(
graph_runtime_state=graph_runtime_state,
start_node_id=self.application_generate_entity.start_node_id,
workflow=workflow,
)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
graph_runtime_state=graph_runtime_state,
variable_pool=variable_pool,
)
generator = workflow_entry.run()
for event in generator:
self._update_document_status(
event, self.application_generate_entity.document_id, self.application_generate_entity.dataset_id
)
self._handle_event(workflow_entry, event)
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
.first()
)
# return workflow
return workflow
def _init_rag_pipeline_graph(
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
) -> Graph:
"""
Init pipeline graph
"""
graph_config = workflow.graph_dict
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# nodes = graph_config.get("nodes", [])
# edges = graph_config.get("edges", [])
# real_run_nodes = []
# real_edges = []
# exclude_node_ids = []
# for node in nodes:
# node_id = node.get("id")
# node_type = node.get("data", {}).get("type", "")
# if node_type == "datasource":
# if start_node_id != node_id:
# exclude_node_ids.append(node_id)
# continue
# real_run_nodes.append(node)
# for edge in edges:
# if edge.get("source") in exclude_node_ids:
# continue
# real_edges.append(edge)
# graph_config = dict(graph_config)
# graph_config["nodes"] = real_run_nodes
# graph_config["edges"] = real_edges
# init graph
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id=self.application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
if not graph:
raise ValueError("graph not found in workflow")
return graph
def _update_document_status(self, event: GraphEngineEvent, document_id: str | None, dataset_id: str | None) -> None:
"""
Update document status
"""
if isinstance(event, GraphRunFailedEvent):
if document_id and dataset_id:
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
document.indexing_status = "error"
document.error = event.error or "Unknown error"
db.session.add(document)
db.session.commit()

View File

@@ -53,7 +53,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: Literal[True],
         call_depth: int,
-        workflow_thread_pool_id: str | None,
     ) -> Generator[Mapping | str, None, None]: ...

     @overload
@@ -67,7 +66,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: Literal[False],
         call_depth: int,
-        workflow_thread_pool_id: str | None,
     ) -> Mapping[str, Any]: ...

     @overload
@@ -81,7 +79,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: bool,
         call_depth: int,
-        workflow_thread_pool_id: str | None,
     ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

     def generate(
@@ -94,7 +91,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: bool = True,
         call_depth: int = 0,
-        workflow_thread_pool_id: str | None = None,
     ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
         files: Sequence[Mapping[str, Any]] = args.get("files") or []
@@ -186,7 +182,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
             workflow_execution_repository=workflow_execution_repository,
             workflow_node_execution_repository=workflow_node_execution_repository,
             streaming=streaming,
-            workflow_thread_pool_id=workflow_thread_pool_id,
         )

     def _generate(
@@ -200,7 +195,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         workflow_execution_repository: WorkflowExecutionRepository,
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
         streaming: bool = True,
-        workflow_thread_pool_id: str | None = None,
         variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
     ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
         """
@@ -214,7 +208,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         :param workflow_execution_repository: repository for workflow execution
         :param workflow_node_execution_repository: repository for workflow node execution
         :param streaming: is stream
-        :param workflow_thread_pool_id: workflow thread pool id
         """
         # init queue manager
         queue_manager = WorkflowAppQueueManager(
@@ -237,16 +230,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
                 "application_generate_entity": application_generate_entity,
                 "queue_manager": queue_manager,
                 "context": context,
-                "workflow_thread_pool_id": workflow_thread_pool_id,
                 "variable_loader": variable_loader,
             },
         )
         worker_thread.start()

-        draft_var_saver_factory = self._get_draft_var_saver_factory(
-            invoke_from,
-        )
+        draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user)

         # return response or stream generator
         response = self._handle_response(
@@ -434,8 +424,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         queue_manager: AppQueueManager,
         context: contextvars.Context,
         variable_loader: VariableLoader,
-        workflow_thread_pool_id: str | None = None,
-    ):
+    ) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
@@ -444,7 +433,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         :param workflow_thread_pool_id: workflow thread pool id
         :return:
         """
-
         with preserve_flask_contexts(flask_app, context_vars=context):
             with Session(db.engine, expire_on_commit=False) as session:
                 workflow = session.scalar(
@@ -474,7 +462,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
                 runner = WorkflowAppRunner(
                     application_generate_entity=application_generate_entity,
                     queue_manager=queue_manager,
-                    workflow_thread_pool_id=workflow_thread_pool_id,
                     variable_loader=variable_loader,
                     workflow=workflow,
                     system_user_id=system_user_id,
View File

@@ -1,7 +1,7 @@
 import logging
+import time
 from typing import cast

-from configs import dify_config
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@@ -9,13 +9,14 @@ from core.app.entities.app_invoke_entities import (
     InvokeFrom,
     WorkflowAppGenerateEntity,
 )
-from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
-from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.entities import GraphRuntimeState, VariablePool
+from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
 from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import VariableLoader
 from core.workflow.workflow_entry import WorkflowEntry
+from extensions.ext_redis import redis_client
 from models.enums import UserFrom
-from models.workflow import Workflow, WorkflowType
+from models.workflow import Workflow

 logger = logging.getLogger(__name__)

@@ -31,7 +32,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         application_generate_entity: WorkflowAppGenerateEntity,
         queue_manager: AppQueueManager,
         variable_loader: VariableLoader,
-        workflow_thread_pool_id: str | None = None,
         workflow: Workflow,
         system_user_id: str,
     ):
@@ -41,7 +41,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             app_id=application_generate_entity.app_config.app_id,
         )
         self.application_generate_entity = application_generate_entity
-        self.workflow_thread_pool_id = workflow_thread_pool_id
         self._workflow = workflow
         self._sys_user_id = system_user_id

@@ -52,24 +51,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         app_config = self.application_generate_entity.app_config
         app_config = cast(WorkflowAppConfig, app_config)

-        workflow_callbacks: list[WorkflowCallback] = []
-        if dify_config.DEBUG:
-            workflow_callbacks.append(WorkflowLoggingCallback())
-
-        # if only single iteration run is requested
-        if self.application_generate_entity.single_iteration_run:
-            # if only single iteration run is requested
-            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
+        # if only single iteration or single loop run is requested
+        if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+            graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
                 workflow=self._workflow,
-                node_id=self.application_generate_entity.single_iteration_run.node_id,
-                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
-            )
-        elif self.application_generate_entity.single_loop_run:
-            # if only single loop run is requested
-            graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
-                workflow=self._workflow,
-                node_id=self.application_generate_entity.single_loop_run.node_id,
-                user_inputs=self.application_generate_entity.single_loop_run.inputs,
+                single_iteration_run=self.application_generate_entity.single_iteration_run,
+                single_loop_run=self.application_generate_entity.single_loop_run,
             )
         else:
             inputs = self.application_generate_entity.inputs
@@ -92,15 +79,27 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
                 conversation_variables=[],
             )

+            graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+
         # init graph
-        graph = self._init_graph(graph_config=self._workflow.graph_dict)
+        graph = self._init_graph(
+            graph_config=self._workflow.graph_dict,
+            graph_runtime_state=graph_runtime_state,
+            workflow_id=self._workflow.id,
+            tenant_id=self._workflow.tenant_id,
+            user_id=self.application_generate_entity.user_id,
+        )

         # RUN WORKFLOW
+        # Create Redis command channel for this workflow execution
+        task_id = self.application_generate_entity.task_id
+        channel_key = f"workflow:{task_id}:commands"
+        command_channel = RedisChannel(redis_client, channel_key)
+
         workflow_entry = WorkflowEntry(
             tenant_id=self._workflow.tenant_id,
             app_id=self._workflow.app_id,
             workflow_id=self._workflow.id,
-            workflow_type=WorkflowType.value_of(self._workflow.type),
             graph=graph,
             graph_config=self._workflow.graph_dict,
             user_id=self.application_generate_entity.user_id,
@@ -112,10 +111,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             invoke_from=self.application_generate_entity.invoke_from,
             call_depth=self.application_generate_entity.call_depth,
             variable_pool=variable_pool,
-            thread_pool_id=self.workflow_thread_pool_id,
+            graph_runtime_state=graph_runtime_state,
+            command_channel=command_channel,
         )

-        generator = workflow_entry.run(callbacks=workflow_callbacks)
+        generator = workflow_entry.run()

         for event in generator:
             self._handle_event(workflow_entry, event)
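The new per-task command channel keys commands by task id (`workflow:{task_id}:commands`). A rough sketch of the idea using plain redis-py, assuming a reachable Redis instance; the JSON payload shape and the list-based polling style here are assumptions for illustration, not RedisChannel's actual wire protocol:

import json
import redis  # assumes a reachable Redis instance

r = redis.Redis()
task_id = "example-task-id"  # hypothetical id for illustration
channel_key = f"workflow:{task_id}:commands"

# Controller side: enqueue a command for the running workflow.
r.rpush(channel_key, json.dumps({"command": "abort", "reason": "user stop"}))

# Engine side: poll the per-task key between node executions.
raw = r.lpop(channel_key)
if raw is not None:
    command = json.loads(raw)
    print("received command:", command["command"])

Keying the channel by task id means any API worker can signal a run that lives on a different worker, which a per-process thread pool id could not do.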

View File

@@ -2,7 +2,7 @@ import logging
 import time
 from collections.abc import Callable, Generator
 from contextlib import contextmanager
-from typing import Any, Union
+from typing import Union

 from sqlalchemy.orm import Session

@@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
     WorkflowAppGenerateEntity,
 )
 from core.app.entities.queue_entities import (
+    AppQueueEvent,
     MessageQueueMessage,
     QueueAgentLogEvent,
     QueueErrorEvent,
@@ -25,14 +26,9 @@ from core.app.entities.queue_entities import (
     QueueLoopStartEvent,
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
-    QueueNodeInIterationFailedEvent,
-    QueueNodeInLoopFailedEvent,
     QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
-    QueueParallelBranchRunFailedEvent,
-    QueueParallelBranchRunStartedEvent,
-    QueueParallelBranchRunSucceededEvent,
     QueuePingEvent,
     QueueStopEvent,
     QueueTextChunkEvent,
@@ -57,8 +53,8 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
-from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.entities import GraphRuntimeState, WorkflowExecution
+from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
 from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@@ -349,9 +345,7 @@ class WorkflowAppGenerateTaskPipeline:
     def _handle_node_failed_events(
         self,
-        event: Union[
-            QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
-        ],
+        event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle various node failure events."""
@@ -370,32 +364,6 @@
         if node_failed_response:
             yield node_failed_response

-    def _handle_parallel_branch_started_event(
-        self, event: QueueParallelBranchRunStartedEvent, **kwargs
-    ) -> Generator[StreamResponse, None, None]:
-        """Handle parallel branch started events."""
-        self._ensure_workflow_initialized()
-        parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
-            task_id=self._application_generate_entity.task_id,
-            workflow_execution_id=self._workflow_run_id,
-            event=event,
-        )
-        yield parallel_start_resp
-
-    def _handle_parallel_branch_finished_events(
-        self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
-    ) -> Generator[StreamResponse, None, None]:
-        """Handle parallel branch finished events."""
-        self._ensure_workflow_initialized()
-        parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
-            task_id=self._application_generate_entity.task_id,
-            workflow_execution_id=self._workflow_run_id,
-            event=event,
-        )
-        yield parallel_finish_resp
-
     def _handle_iteration_start_event(
         self, event: QueueIterationStartEvent, **kwargs
     ) -> Generator[StreamResponse, None, None]:
@@ -617,8 +585,6 @@
             QueueNodeRetryEvent: self._handle_node_retry_event,
             QueueNodeStartedEvent: self._handle_node_started_event,
             QueueNodeSucceededEvent: self._handle_node_succeeded_event,
-            # Parallel branch events
-            QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
             # Iteration events
             QueueIterationStartEvent: self._handle_iteration_start_event,
             QueueIterationNextEvent: self._handle_iteration_next_event,
@@ -633,7 +599,7 @@
     def _dispatch_event(
         self,
-        event: Any,
+        event: AppQueueEvent,
         *,
         graph_runtime_state: GraphRuntimeState | None = None,
         tts_publisher: AppGeneratorTTSPublisher | None = None,
@@ -660,8 +626,6 @@
             event,
             (
                 QueueNodeFailedEvent,
-                QueueNodeInIterationFailedEvent,
-                QueueNodeInLoopFailedEvent,
                 QueueNodeExceptionEvent,
             ),
         ):
@@ -674,17 +638,6 @@
             )
             return

-        # Handle parallel branch finished events with isinstance check
-        if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
-            yield from self._handle_parallel_branch_finished_events(
-                event,
-                graph_runtime_state=graph_runtime_state,
-                tts_publisher=tts_publisher,
-                trace_manager=trace_manager,
-                queue_message=queue_message,
-            )
-            return
-
         # Handle workflow failed and stop events with isinstance check
         if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
             yield from self._handle_workflow_failed_and_stop_events(
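The pipeline dispatches exact event types through a handler table and falls back to isinstance checks for families of related events, which is why dropping the parallel-branch events touches both places. A small, self-contained illustration of that two-tier dispatch (event names invented for the sketch):

from collections.abc import Callable

class PingEvent: ...
class StopEvent: ...
class NodeFailedEvent: ...
class NodeExceptionEvent: ...

def handle_ping(event: PingEvent) -> str: return "pong"
def handle_stop(event: StopEvent) -> str: return "stopped"
def handle_node_failure(event: object) -> str: return "node failed"

# Exact-type dispatch first ...
HANDLERS: dict[type, Callable] = {PingEvent: handle_ping, StopEvent: handle_stop}

def dispatch(event: object) -> str:
    handler = HANDLERS.get(type(event))
    if handler:
        return handler(event)
    # ... then isinstance checks for families that share one handler.
    if isinstance(event, (NodeFailedEvent, NodeExceptionEvent)):
        return handle_node_failure(event)
    return "ignored"

print(dispatch(PingEvent()), dispatch(NodeExceptionEvent()))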

View File

@@ -1,7 +1,9 @@
+import time
 from collections.abc import Mapping
 from typing import Any, cast

 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import (
@@ -13,14 +15,9 @@ from core.app.entities.queue_entities import (
     QueueLoopStartEvent,
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
-    QueueNodeInIterationFailedEvent,
-    QueueNodeInLoopFailedEvent,
     QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
-    QueueParallelBranchRunFailedEvent,
-    QueueParallelBranchRunStartedEvent,
-    QueueParallelBranchRunSucceededEvent,
     QueueRetrieverResourcesEvent,
     QueueTextChunkEvent,
     QueueWorkflowFailedEvent,
@@ -28,42 +25,39 @@ from core.app.entities.queue_entities import (
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
-from core.workflow.graph_engine.entities.event import (
-    AgentLogEvent,
+from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
+from core.workflow.graph import Graph
+from core.workflow.graph_events import (
     GraphEngineEvent,
     GraphRunFailedEvent,
     GraphRunPartialSucceededEvent,
     GraphRunStartedEvent,
     GraphRunSucceededEvent,
-    IterationRunFailedEvent,
-    IterationRunNextEvent,
-    IterationRunStartedEvent,
-    IterationRunSucceededEvent,
-    LoopRunFailedEvent,
-    LoopRunNextEvent,
-    LoopRunStartedEvent,
-    LoopRunSucceededEvent,
-    NodeInIterationFailedEvent,
-    NodeInLoopFailedEvent,
+    NodeRunAgentLogEvent,
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
+    NodeRunIterationFailedEvent,
+    NodeRunIterationNextEvent,
+    NodeRunIterationStartedEvent,
+    NodeRunIterationSucceededEvent,
+    NodeRunLoopFailedEvent,
+    NodeRunLoopNextEvent,
+    NodeRunLoopStartedEvent,
+    NodeRunLoopSucceededEvent,
     NodeRunRetrieverResourceEvent,
     NodeRunRetryEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
-    ParallelBranchRunFailedEvent,
-    ParallelBranchRunStartedEvent,
-    ParallelBranchRunSucceededEvent,
 )
-from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_events.graph import GraphRunAbortedEvent
 from core.workflow.nodes import NodeType
+from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
 from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
 from core.workflow.workflow_entry import WorkflowEntry
+from models.enums import UserFrom
 from models.workflow import Workflow
@@ -79,7 +73,14 @@ class WorkflowBasedAppRunner:
         self._variable_loader = variable_loader
         self._app_id = app_id

-    def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
+    def _init_graph(
+        self,
+        graph_config: Mapping[str, Any],
+        graph_runtime_state: GraphRuntimeState,
+        workflow_id: str = "",
+        tenant_id: str = "",
+        user_id: str = "",
+    ) -> Graph:
         """
         Init graph
         """
@@ -91,22 +92,109 @@ class WorkflowBasedAppRunner:
         if not isinstance(graph_config.get("edges"), list):
             raise ValueError("edges in workflow graph must be a list")

+        # Create required parameters for Graph.init
+        graph_init_params = GraphInitParams(
+            tenant_id=tenant_id or "",
+            app_id=self._app_id,
+            workflow_id=workflow_id,
+            graph_config=graph_config,
+            user_id=user_id,
+            user_from=UserFrom.ACCOUNT.value,
+            invoke_from=InvokeFrom.SERVICE_API.value,
+            call_depth=0,
+        )
+
+        # Use the provided graph_runtime_state for consistent state management
+        node_factory = DifyNodeFactory(
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+
         # init graph
-        graph = Graph.init(graph_config=graph_config)
+        graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

         if not graph:
             raise ValueError("graph not found in workflow")

         return graph

-    def _get_graph_and_variable_pool_of_single_iteration(
+    def _prepare_single_node_execution(
+        self,
+        workflow: Workflow,
+        single_iteration_run: Any | None = None,
+        single_loop_run: Any | None = None,
+    ) -> tuple[Graph, VariablePool, GraphRuntimeState]:
+        """
+        Prepare graph, variable pool, and runtime state for single node execution
+        (either single iteration or single loop).
+
+        Args:
+            workflow: The workflow instance
+            single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
+            single_loop_run: SingleLoopRunEntity if running single loop, None otherwise
+
+        Returns:
+            A tuple containing (graph, variable_pool, graph_runtime_state)
+
+        Raises:
+            ValueError: If neither single_iteration_run nor single_loop_run is specified
+        """
+        # Create initial runtime state with variable pool containing environment variables
+        graph_runtime_state = GraphRuntimeState(
+            variable_pool=VariablePool(
+                system_variables=SystemVariable.empty(),
+                user_inputs={},
+                environment_variables=workflow.environment_variables,
+            ),
+            start_at=time.time(),
+        )
+
+        # Determine which type of single node execution and get graph/variable_pool
+        if single_iteration_run:
+            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
+                workflow=workflow,
+                node_id=single_iteration_run.node_id,
+                user_inputs=dict(single_iteration_run.inputs),
+                graph_runtime_state=graph_runtime_state,
+            )
+        elif single_loop_run:
+            graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
+                workflow=workflow,
+                node_id=single_loop_run.node_id,
+                user_inputs=dict(single_loop_run.inputs),
+                graph_runtime_state=graph_runtime_state,
+            )
+        else:
+            raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
+
+        # Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
+        # This ensures all nodes in the graph reference the same GraphRuntimeState instance
+        return graph, variable_pool, graph_runtime_state
+
+    def _get_graph_and_variable_pool_for_single_node_run(
         self,
         workflow: Workflow,
         node_id: str,
-        user_inputs: dict,
+        user_inputs: dict[str, Any],
+        graph_runtime_state: GraphRuntimeState,
+        node_type_filter_key: str,  # 'iteration_id' or 'loop_id'
+        node_type_label: str = "node",  # 'iteration' or 'loop' for error messages
     ) -> tuple[Graph, VariablePool]:
         """
-        Get variable pool of single iteration
+        Get graph and variable pool for single node execution (iteration or loop).
+
+        Args:
+            workflow: The workflow instance
+            node_id: The node ID to execute
+            user_inputs: User inputs for the node
+            graph_runtime_state: The graph runtime state
+            node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
+            node_type_label: Label for error messages ('iteration' or 'loop')
+
+        Returns:
+            A tuple containing (graph, variable_pool)
         """
         # fetch workflow graph
         graph_config = workflow.graph_dict
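The refactor threads one GraphRuntimeState through graph construction so every node created by the factory shares the same variable pool. A toy sketch of why sharing the instance matters (all names illustrative, not Dify's types):

from dataclasses import dataclass, field

@dataclass
class RuntimeState:
    # One shared state object; every node reads and writes the same pool.
    variable_pool: dict = field(default_factory=dict)

@dataclass
class Node:
    state: RuntimeState

    def run(self, key: str, value: str) -> None:
        self.state.variable_pool[key] = value

state = RuntimeState()
a, b = Node(state), Node(state)
a.run("greeting", "hello")
assert b.state.variable_pool["greeting"] == "hello"  # same instance everywhere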
@@ -124,18 +212,22 @@
         if not isinstance(graph_config.get("edges"), list):
             raise ValueError("edges in workflow graph must be a list")

-        # filter nodes only in iteration
+        # filter nodes only in the specified node type (iteration or loop)
+        main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
+        start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
         node_configs = [
             node
             for node in graph_config.get("nodes", [])
-            if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
+            if node.get("id") == node_id
+            or node.get("data", {}).get(node_type_filter_key, "") == node_id
+            or (start_node_id and node.get("id") == start_node_id)
         ]

         graph_config["nodes"] = node_configs

         node_ids = [node.get("id") for node in node_configs]

-        # filter edges only in iteration
+        # filter edges only in the specified node type
         edge_configs = [
             edge
             for edge in graph_config.get("edges", [])
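The node filter now also keeps the container's start node. A runnable miniature of the same filtering rules on a hand-built graph dict (ids invented for illustration):

graph = {
    "nodes": [
        {"id": "iter-1", "data": {"type": "iteration", "start_node_id": "iter-start"}},
        {"id": "iter-start", "data": {"iteration_id": "iter-1"}},
        {"id": "step", "data": {"iteration_id": "iter-1"}},
        {"id": "unrelated", "data": {}},
    ],
    "edges": [
        {"source": "iter-start", "target": "step"},
        {"source": "unrelated", "target": "iter-1"},
    ],
}

node_id, filter_key = "iter-1", "iteration_id"  # would be "loop_id" for a loop
main = next((n for n in graph["nodes"] if n["id"] == node_id), None)
start_node_id = main["data"].get("start_node_id") if main else None

# Keep the container node, its children, and the container's start node.
nodes = [
    n
    for n in graph["nodes"]
    if n["id"] == node_id
    or n["data"].get(filter_key, "") == node_id
    or (start_node_id and n["id"] == start_node_id)
]
node_ids = {n["id"] for n in nodes}

# Keep only edges whose endpoints both survived the node filter.
edges = [
    e
    for e in graph["edges"]
    if (e.get("source") is None or e.get("source") in node_ids)
    and (e.get("target") is None or e.get("target") in node_ids)
]
print(sorted(node_ids), edges)  # "unrelated" and its edge are dropped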
@@ -145,37 +237,50 @@
         graph_config["edges"] = edge_configs

+        # Create required parameters for Graph.init
+        graph_init_params = GraphInitParams(
+            tenant_id=workflow.tenant_id,
+            app_id=self._app_id,
+            workflow_id=workflow.id,
+            graph_config=graph_config,
+            user_id="",
+            user_from=UserFrom.ACCOUNT.value,
+            invoke_from=InvokeFrom.SERVICE_API.value,
+            call_depth=0,
+        )
+
+        node_factory = DifyNodeFactory(
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+
         # init graph
-        graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
+        graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)

         if not graph:
             raise ValueError("graph not found in workflow")

         # fetch node config from node id
-        iteration_node_config = None
+        target_node_config = None
         for node in node_configs:
             if node.get("id") == node_id:
-                iteration_node_config = node
+                target_node_config = node
                 break

-        if not iteration_node_config:
-            raise ValueError("iteration node id not found in workflow graph")
+        if not target_node_config:
+            raise ValueError(f"{node_type_label} node id not found in workflow graph")

         # Get node class
-        node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
-        node_version = iteration_node_config.get("data", {}).get("version", "1")
+        node_type = NodeType(target_node_config.get("data", {}).get("type"))
+        node_version = target_node_config.get("data", {}).get("version", "1")
         node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]

-        # init variable pool
-        variable_pool = VariablePool(
-            system_variables=SystemVariable.empty(),
-            user_inputs={},
-            environment_variables=workflow.environment_variables,
-        )
+        # Use the variable pool from graph_runtime_state instead of creating a new one
+        variable_pool = graph_runtime_state.variable_pool

         try:
             variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
-                graph_config=workflow.graph_dict, config=iteration_node_config
+                graph_config=workflow.graph_dict, config=target_node_config
             )
         except NotImplementedError:
             variable_mapping = {}
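extract_variable_selector_to_variable_mapping is treated here as an optional capability: node classes that don't support it raise NotImplementedError and the runner falls back to an empty mapping. A sketch of that pattern with hypothetical node classes:

class BaseNode:
    @classmethod
    def extract_variable_mapping(cls, config: dict) -> dict:
        # Default: nodes that don't support pre-loading declare it this way.
        raise NotImplementedError

class CodeNode(BaseNode):
    @classmethod
    def extract_variable_mapping(cls, config: dict) -> dict:
        return {f"{config['id']}.x": ["sys", "query"]}

def mapping_for(node_cls: type[BaseNode], config: dict) -> dict:
    try:
        return node_cls.extract_variable_mapping(config)
    except NotImplementedError:
        return {}  # fall back to an empty mapping, as the runner does

print(mapping_for(CodeNode, {"id": "n1"}), mapping_for(BaseNode, {"id": "n2"}))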
@@ -196,102 +301,44 @@

         return graph, variable_pool

+    def _get_graph_and_variable_pool_of_single_iteration(
+        self,
+        workflow: Workflow,
+        node_id: str,
+        user_inputs: dict[str, Any],
+        graph_runtime_state: GraphRuntimeState,
+    ) -> tuple[Graph, VariablePool]:
+        """
+        Get variable pool of single iteration
+        """
+        return self._get_graph_and_variable_pool_for_single_node_run(
+            workflow=workflow,
+            node_id=node_id,
+            user_inputs=user_inputs,
+            graph_runtime_state=graph_runtime_state,
+            node_type_filter_key="iteration_id",
+            node_type_label="iteration",
+        )
+
     def _get_graph_and_variable_pool_of_single_loop(
         self,
         workflow: Workflow,
         node_id: str,
-        user_inputs: dict,
+        user_inputs: dict[str, Any],
+        graph_runtime_state: GraphRuntimeState,
     ) -> tuple[Graph, VariablePool]:
         """
         Get variable pool of single loop
         """
-        # fetch workflow graph
-        graph_config = workflow.graph_dict
-        if not graph_config:
-            raise ValueError("workflow graph not found")
-
-        graph_config = cast(dict[str, Any], graph_config)
-
-        if "nodes" not in graph_config or "edges" not in graph_config:
-            raise ValueError("nodes or edges not found in workflow graph")
-
-        if not isinstance(graph_config.get("nodes"), list):
-            raise ValueError("nodes in workflow graph must be a list")
-
-        if not isinstance(graph_config.get("edges"), list):
-            raise ValueError("edges in workflow graph must be a list")
-
-        # filter nodes only in loop
-        node_configs = [
-            node
-            for node in graph_config.get("nodes", [])
-            if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
-        ]
-
-        graph_config["nodes"] = node_configs
-
-        node_ids = [node.get("id") for node in node_configs]
-
-        # filter edges only in loop
-        edge_configs = [
-            edge
-            for edge in graph_config.get("edges", [])
-            if (edge.get("source") is None or edge.get("source") in node_ids)
-            and (edge.get("target") is None or edge.get("target") in node_ids)
-        ]
-
-        graph_config["edges"] = edge_configs
-
-        # init graph
-        graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
-
-        if not graph:
-            raise ValueError("graph not found in workflow")
-
-        # fetch node config from node id
-        loop_node_config = None
-        for node in node_configs:
-            if node.get("id") == node_id:
-                loop_node_config = node
-                break
-
-        if not loop_node_config:
-            raise ValueError("loop node id not found in workflow graph")
-
-        # Get node class
-        node_type = NodeType(loop_node_config.get("data", {}).get("type"))
-        node_version = loop_node_config.get("data", {}).get("version", "1")
-        node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
-
-        # init variable pool
-        variable_pool = VariablePool(
-            system_variables=SystemVariable.empty(),
-            user_inputs={},
-            environment_variables=workflow.environment_variables,
-        )
-
-        try:
-            variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
-                graph_config=workflow.graph_dict, config=loop_node_config
-            )
-        except NotImplementedError:
-            variable_mapping = {}
-
-        load_into_variable_pool(
-            self._variable_loader,
-            variable_pool=variable_pool,
-            variable_mapping=variable_mapping,
-            user_inputs=user_inputs,
-        )
-
-        WorkflowEntry.mapping_user_inputs_to_variable_pool(
-            variable_mapping=variable_mapping,
+        return self._get_graph_and_variable_pool_for_single_node_run(
+            workflow=workflow,
+            node_id=node_id,
             user_inputs=user_inputs,
-            variable_pool=variable_pool,
-            tenant_id=workflow.tenant_id,
+            graph_runtime_state=graph_runtime_state,
+            node_type_filter_key="loop_id",
+            node_type_label="loop",
         )
-
-        return graph, variable_pool

     def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
         """
         Handle event
@@ -310,39 +357,32 @@
             )
         elif isinstance(event, GraphRunFailedEvent):
             self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
+        elif isinstance(event, GraphRunAbortedEvent):
+            self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
         elif isinstance(event, NodeRunRetryEvent):
-            node_run_result = event.route_node_state.node_run_result
-            inputs: Mapping[str, Any] | None = {}
-            process_data: Mapping[str, Any] | None = {}
-            outputs: Mapping[str, Any] | None = {}
-            execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
-            if node_run_result:
-                inputs = node_run_result.inputs
-                process_data = node_run_result.process_data
-                outputs = node_run_result.outputs
-                execution_metadata = node_run_result.metadata
+            node_run_result = event.node_run_result
+            inputs = node_run_result.inputs
+            process_data = node_run_result.process_data
+            outputs = node_run_result.outputs
+            execution_metadata = node_run_result.metadata
             self._publish_event(
                 QueueNodeRetryEvent(
                     node_execution_id=event.id,
                     node_id=event.node_id,
+                    node_title=event.node_title,
                     node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                     start_at=event.start_at,
-                    node_run_index=event.route_node_state.index,
                     predecessor_node_id=event.predecessor_node_id,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
-                    parallel_mode_run_id=event.parallel_mode_run_id,
                     inputs=inputs,
                     process_data=process_data,
                     outputs=outputs,
                     error=event.error,
                     execution_metadata=execution_metadata,
                     retry_index=event.retry_index,
+                    provider_type=event.provider_type,
+                    provider_id=event.provider_id,
                 )
             )
         elif isinstance(event, NodeRunStartedEvent):
@@ -350,44 +390,29 @@
             self._publish_event(
                 QueueNodeStartedEvent(
                     node_execution_id=event.id,
                     node_id=event.node_id,
+                    node_title=event.node_title,
                     node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    start_at=event.route_node_state.start_at,
-                    node_run_index=event.route_node_state.index,
+                    start_at=event.start_at,
                     predecessor_node_id=event.predecessor_node_id,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
-                    parallel_mode_run_id=event.parallel_mode_run_id,
                     agent_strategy=event.agent_strategy,
+                    provider_type=event.provider_type,
+                    provider_id=event.provider_id,
                 )
             )
         elif isinstance(event, NodeRunSucceededEvent):
-            node_run_result = event.route_node_state.node_run_result
-            if node_run_result:
-                inputs = node_run_result.inputs
-                process_data = node_run_result.process_data
-                outputs = node_run_result.outputs
-                execution_metadata = node_run_result.metadata
-            else:
-                inputs = {}
-                process_data = {}
-                outputs = {}
-                execution_metadata = {}
+            node_run_result = event.node_run_result
+            inputs = node_run_result.inputs
+            process_data = node_run_result.process_data
+            outputs = node_run_result.outputs
+            execution_metadata = node_run_result.metadata
             self._publish_event(
                 QueueNodeSucceededEvent(
                     node_execution_id=event.id,
                     node_id=event.node_id,
                     node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    start_at=event.route_node_state.start_at,
+                    start_at=event.start_at,
                     inputs=inputs,
                     process_data=process_data,
                     outputs=outputs,
@@ -396,34 +421,18 @@
                     in_loop_id=event.in_loop_id,
                 )
             )
         elif isinstance(event, NodeRunFailedEvent):
             self._publish_event(
                 QueueNodeFailedEvent(
                     node_execution_id=event.id,
                     node_id=event.node_id,
                     node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    start_at=event.route_node_state.start_at,
-                    inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result
-                    else {},
-                    outputs=event.route_node_state.node_run_result.outputs or {}
-                    if event.route_node_state.node_run_result
-                    else {},
-                    error=event.route_node_state.node_run_result.error
-                    if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
-                    else "Unknown error",
-                    execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result
-                    else {},
+                    start_at=event.start_at,
+                    inputs=event.node_run_result.inputs,
+                    process_data=event.node_run_result.process_data,
+                    outputs=event.node_run_result.outputs,
+                    error=event.node_run_result.error or "Unknown error",
+                    execution_metadata=event.node_run_result.metadata,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
                 )
@@ -434,93 +443,21 @@
                     node_execution_id=event.id,
                     node_id=event.node_id,
                     node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    start_at=event.route_node_state.start_at,
-                    inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result
-                    else {},
-                    outputs=event.route_node_state.node_run_result.outputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    error=event.route_node_state.node_run_result.error
-                    if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
-                    else "Unknown error",
-                    execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result
-                    else {},
+                    start_at=event.start_at,
+                    inputs=event.node_run_result.inputs,
+                    process_data=event.node_run_result.process_data,
+                    outputs=event.node_run_result.outputs,
+                    error=event.node_run_result.error or "Unknown error",
+                    execution_metadata=event.node_run_result.metadata,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
                 )
             )
-        elif isinstance(event, NodeInIterationFailedEvent):
-            self._publish_event(
-                QueueNodeInIterationFailedEvent(
-                    node_execution_id=event.id,
-                    node_id=event.node_id,
-                    node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    start_at=event.route_node_state.start_at,
-                    inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result
-                    else {},
-                    outputs=event.route_node_state.node_run_result.outputs or {}
-                    if event.route_node_state.node_run_result
-                    else {},
-                    execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result
-                    else {},
-                    in_iteration_id=event.in_iteration_id,
-                    error=event.error,
-                )
-            )
-        elif isinstance(event, NodeInLoopFailedEvent):
-            self._publish_event(
-                QueueNodeInLoopFailedEvent(
-                    node_execution_id=event.id,
-                    node_id=event.node_id,
-                    node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    start_at=event.route_node_state.start_at,
-                    inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result
-                    else {},
-                    outputs=event.route_node_state.node_run_result.outputs or {}
-                    if event.route_node_state.node_run_result
-                    else {},
-                    execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result
-                    else {},
-                    in_loop_id=event.in_loop_id,
-                    error=event.error,
-                )
-            )
         elif isinstance(event, NodeRunStreamChunkEvent):
             self._publish_event(
                 QueueTextChunkEvent(
-                    text=event.chunk_content,
-                    from_variable_selector=event.from_variable_selector,
+                    text=event.chunk,
+                    from_variable_selector=list(event.selector),
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
                 )
@@ -533,10 +470,10 @@
                     in_loop_id=event.in_loop_id,
                 )
             )
-        elif isinstance(event, AgentLogEvent):
+        elif isinstance(event, NodeRunAgentLogEvent):
             self._publish_event(
                 QueueAgentLogEvent(
-                    id=event.id,
+                    id=event.message_id,
                     label=event.label,
                     node_execution_id=event.node_execution_id,
                     parent_id=event.parent_id,
@@ -547,51 +484,13 @@
                     node_id=event.node_id,
                 )
             )
-        elif isinstance(event, ParallelBranchRunStartedEvent):
-            self._publish_event(
-                QueueParallelBranchRunStartedEvent(
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    in_iteration_id=event.in_iteration_id,
-                    in_loop_id=event.in_loop_id,
-                )
-            )
-        elif isinstance(event, ParallelBranchRunSucceededEvent):
-            self._publish_event(
-                QueueParallelBranchRunSucceededEvent(
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    in_iteration_id=event.in_iteration_id,
-                    in_loop_id=event.in_loop_id,
-                )
-            )
-        elif isinstance(event, ParallelBranchRunFailedEvent):
-            self._publish_event(
-                QueueParallelBranchRunFailedEvent(
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    in_iteration_id=event.in_iteration_id,
-                    in_loop_id=event.in_loop_id,
-                    error=event.error,
-                )
-            )
-        elif isinstance(event, IterationRunStartedEvent):
+        elif isinstance(event, NodeRunIterationStartedEvent):
             self._publish_event(
                 QueueIterationStartEvent(
-                    node_execution_id=event.iteration_id,
-                    node_id=event.iteration_node_id,
-                    node_type=event.iteration_node_type,
-                    node_data=event.iteration_node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_title=event.node_title,
                     start_at=event.start_at,
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     inputs=event.inputs,
@@ -599,55 +498,41 @@
                     metadata=event.metadata,
                 )
             )
-        elif isinstance(event, IterationRunNextEvent):
+        elif isinstance(event, NodeRunIterationNextEvent):
             self._publish_event(
                 QueueIterationNextEvent(
-                    node_execution_id=event.iteration_id,
-                    node_id=event.iteration_node_id,
-                    node_type=event.iteration_node_type,
-                    node_data=event.iteration_node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_title=event.node_title,
                     index=event.index,
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     output=event.pre_iteration_output,
-                    parallel_mode_run_id=event.parallel_mode_run_id,
-                    duration=event.duration,
                 )
             )
-        elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
+        elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
             self._publish_event(
                 QueueIterationCompletedEvent(
-                    node_execution_id=event.iteration_id,
-                    node_id=event.iteration_node_id,
-                    node_type=event.iteration_node_type,
-                    node_data=event.iteration_node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_title=event.node_title,
                     start_at=event.start_at,
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     inputs=event.inputs,
                     outputs=event.outputs,
                     metadata=event.metadata,
                     steps=event.steps,
-                    error=event.error if isinstance(event, IterationRunFailedEvent) else None,
+                    error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
                 )
             )
-        elif isinstance(event, LoopRunStartedEvent):
+        elif isinstance(event, NodeRunLoopStartedEvent):
             self._publish_event(
                 QueueLoopStartEvent(
-                    node_execution_id=event.loop_id,
-                    node_id=event.loop_node_id,
-                    node_type=event.loop_node_type,
-                    node_data=event.loop_node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_title=event.node_title,
                     start_at=event.start_at,
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     inputs=event.inputs,
@@ -655,42 +540,32 @@
                     metadata=event.metadata,
                 )
             )
-        elif isinstance(event, LoopRunNextEvent):
+        elif isinstance(event, NodeRunLoopNextEvent):
             self._publish_event(
                 QueueLoopNextEvent(
-                    node_execution_id=event.loop_id,
-                    node_id=event.loop_node_id,
-                    node_type=event.loop_node_type,
-                    node_data=event.loop_node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_title=event.node_title,
                     index=event.index,
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     output=event.pre_loop_output,
-                    parallel_mode_run_id=event.parallel_mode_run_id,
-                    duration=event.duration,
                 )
             )
-        elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)):
+        elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
             self._publish_event(
                 QueueLoopCompletedEvent(
-                    node_execution_id=event.loop_id,
-                    node_id=event.loop_node_id,
-                    node_type=event.loop_node_type,
-                    node_data=event.loop_node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_title=event.node_title,
                     start_at=event.start_at,
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     inputs=event.inputs,
                     outputs=event.outputs,
                     metadata=event.metadata,
                     steps=event.steps,
-                    error=event.error if isinstance(event, LoopRunFailedEvent) else None,
+                    error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
                )
            )
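_handle_event is essentially a translation layer from graph-engine events to queue events, normalizing values at the boundary (for example, defaulting a missing error to "Unknown error") so downstream consumers never deal with None. A condensed sketch of that boundary with made-up event types:

from dataclasses import dataclass

@dataclass
class NodeRunFailed:        # engine-side event (rich, engine-specific)
    node_id: str
    error: str | None

@dataclass
class QueueNodeFailed:      # queue-side event the task pipeline consumes
    node_id: str
    error: str

def translate(event: object) -> object | None:
    # Normalize while crossing the boundary so consumers never see None.
    if isinstance(event, NodeRunFailed):
        return QueueNodeFailed(node_id=event.node_id, error=event.error or "Unknown error")
    return None

print(translate(NodeRunFailed(node_id="n1", error=None)))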

View File

@@ -1,9 +1,12 @@
 from collections.abc import Mapping, Sequence
 from enum import StrEnum
-from typing import Any
+from typing import TYPE_CHECKING, Any, Optional

 from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

+if TYPE_CHECKING:
+    from core.ops.ops_trace_manager import TraceQueueManager
+
 from constants import UUID_NIL
 from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
 from core.entities.provider_configuration import ProviderModelBundle
@@ -35,6 +38,7 @@ class InvokeFrom(StrEnum):
     # DEBUGGER indicates that this invocation is from
     # the workflow (or chatflow) edit page.
     DEBUGGER = "debugger"
+    PUBLISHED = "published"

     @classmethod
     def value_of(cls, value: str):
@@ -113,8 +117,7 @@ class AppGenerateEntity(BaseModel):
     extras: dict[str, Any] = Field(default_factory=dict)

     # tracing instance
-    # Using Any to avoid circular import with TraceQueueManager
-    trace_manager: Any | None = None
+    trace_manager: Optional["TraceQueueManager"] = None


 class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
@@ -240,3 +243,34 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
         inputs: dict

     single_loop_run: SingleLoopRunEntity | None = None
+
+
+class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
+    """
+    RAG Pipeline Application Generate Entity.
+    """
+
+    # pipeline config
+    pipeline_config: WorkflowUIBasedAppConfig
+    datasource_type: str
+    datasource_info: Mapping[str, Any]
+    dataset_id: str
+    batch: str
+    document_id: str | None = None
+    original_document_id: str | None = None
+    start_node_id: str | None = None
+
+
+# Import TraceQueueManager at runtime to resolve forward references
+from core.ops.ops_trace_manager import TraceQueueManager
+
+# Rebuild models that use forward references
+AppGenerateEntity.model_rebuild()
+EasyUIBasedAppGenerateEntity.model_rebuild()
+ConversationAppGenerateEntity.model_rebuild()
+ChatAppGenerateEntity.model_rebuild()
+CompletionAppGenerateEntity.model_rebuild()
+AgentChatAppGenerateEntity.model_rebuild()
+AdvancedChatAppGenerateEntity.model_rebuild()
+WorkflowAppGenerateEntity.model_rebuild()
+RagPipelineGenerateEntity.model_rebuild()
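The TYPE_CHECKING import plus the runtime import and model_rebuild() calls are the standard pydantic v2 recipe for typing a field with a class that would otherwise create an import cycle: type checkers see the real type, while at runtime the annotation stays a string until the module is fully loaded and the model is rebuilt. The same pattern with a stand-in class (Decimal substitutes for TraceQueueManager purely so the sketch runs):

from typing import TYPE_CHECKING, Optional
from pydantic import BaseModel

if TYPE_CHECKING:
    # Only seen by type checkers; avoids the circular import at runtime.
    from decimal import Decimal

class Order(BaseModel):
    # A string annotation stays unresolved until model_rebuild().
    total: Optional["Decimal"] = None

# Runtime import, then rebuild to resolve the forward reference.
from decimal import Decimal  # noqa: E402
Order.model_rebuild()

print(Order(total=Decimal("9.99")))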

View File

@@ -3,15 +3,13 @@ from datetime import datetime
from enum import StrEnum, auto from enum import StrEnum, auto
from typing import Any from typing import Any
from pydantic import BaseModel from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
class QueueEvent(StrEnum): class QueueEvent(StrEnum):
@@ -43,9 +41,6 @@ class QueueEvent(StrEnum):
ANNOTATION_REPLY = "annotation_reply" ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought" AGENT_THOUGHT = "agent_thought"
MESSAGE_FILE = "message_file" MESSAGE_FILE = "message_file"
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
AGENT_LOG = "agent_log" AGENT_LOG = "agent_log"
ERROR = "error" ERROR = "error"
PING = "ping" PING = "ping"
@@ -80,21 +75,13 @@ class QueueIterationStartEvent(AppQueueEvent):
node_execution_id: str node_execution_id: str
node_id: str node_id: str
node_type: NodeType node_type: NodeType
node_data: BaseNodeData node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime start_at: datetime
node_run_index: int node_run_index: int
inputs: Mapping[str, Any] | None = None inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None predecessor_node_id: str | None = None
metadata: Mapping[str, Any] | None = None metadata: Mapping[str, object] = Field(default_factory=dict)
class QueueIterationNextEvent(AppQueueEvent): class QueueIterationNextEvent(AppQueueEvent):
@@ -108,20 +95,9 @@ class QueueIterationNextEvent(AppQueueEvent):
node_execution_id: str node_execution_id: str
node_id: str node_id: str
node_type: NodeType node_type: NodeType
node_data: BaseNodeData node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: str | None = None
"""iteration run in parallel mode run id"""
node_run_index: int node_run_index: int
output: Any | None = None # output for the current iteration output: Any = None # output for the current iteration
duration: float | None = None
class QueueIterationCompletedEvent(AppQueueEvent): class QueueIterationCompletedEvent(AppQueueEvent):
@@ -134,21 +110,13 @@ class QueueIterationCompletedEvent(AppQueueEvent):
node_execution_id: str node_execution_id: str
node_id: str node_id: str
node_type: NodeType node_type: NodeType
node_data: BaseNodeData node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime start_at: datetime
node_run_index: int node_run_index: int
inputs: Mapping[str, Any] | None = None inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, Any] | None = None outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, Any] | None = None metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0 steps: int = 0
error: str | None = None error: str | None = None
@@ -163,7 +131,7 @@ class QueueLoopStartEvent(AppQueueEvent):
node_execution_id: str node_execution_id: str
node_id: str node_id: str
node_type: NodeType node_type: NodeType
node_data: BaseNodeData node_title: str
parallel_id: str | None = None parallel_id: str | None = None
"""parallel id if node is in parallel""" """parallel id if node is in parallel"""
parallel_start_node_id: str | None = None parallel_start_node_id: str | None = None
@@ -175,9 +143,9 @@ class QueueLoopStartEvent(AppQueueEvent):
start_at: datetime start_at: datetime
node_run_index: int node_run_index: int
inputs: Mapping[str, Any] | None = None inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None predecessor_node_id: str | None = None
metadata: Mapping[str, Any] | None = None metadata: Mapping[str, object] = Field(default_factory=dict)
class QueueLoopNextEvent(AppQueueEvent): class QueueLoopNextEvent(AppQueueEvent):
@@ -191,7 +159,7 @@ class QueueLoopNextEvent(AppQueueEvent):
node_execution_id: str node_execution_id: str
node_id: str node_id: str
node_type: NodeType node_type: NodeType
node_data: BaseNodeData node_title: str
parallel_id: str | None = None parallel_id: str | None = None
"""parallel id if node is in parallel""" """parallel id if node is in parallel"""
parallel_start_node_id: str | None = None parallel_start_node_id: str | None = None
@@ -203,8 +171,7 @@ class QueueLoopNextEvent(AppQueueEvent):
parallel_mode_run_id: str | None = None parallel_mode_run_id: str | None = None
"""iteration run in parallel mode run id""" """iteration run in parallel mode run id"""
node_run_index: int node_run_index: int
output: Any | None = None # output for the current loop output: Any = None # output for the current loop
duration: float | None = None
class QueueLoopCompletedEvent(AppQueueEvent): class QueueLoopCompletedEvent(AppQueueEvent):
@@ -217,7 +184,7 @@ class QueueLoopCompletedEvent(AppQueueEvent):
node_execution_id: str node_execution_id: str
node_id: str node_id: str
node_type: NodeType node_type: NodeType
node_data: BaseNodeData node_title: str
parallel_id: str | None = None parallel_id: str | None = None
"""parallel id if node is in parallel""" """parallel id if node is in parallel"""
parallel_start_node_id: str | None = None parallel_start_node_id: str | None = None
@@ -229,9 +196,9 @@ class QueueLoopCompletedEvent(AppQueueEvent):
start_at: datetime start_at: datetime
node_run_index: int node_run_index: int
inputs: Mapping[str, Any] | None = None inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, Any] | None = None outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, Any] | None = None metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0 steps: int = 0
error: str | None = None error: str | None = None
@@ -332,7 +299,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
     """
     event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
-    outputs: dict[str, Any] | None = None
+    outputs: Mapping[str, object] = Field(default_factory=dict)

 class QueueWorkflowFailedEvent(AppQueueEvent):
@@ -352,7 +319,7 @@ class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
     event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
     exceptions_count: int
-    outputs: dict[str, Any] | None = None
+    outputs: Mapping[str, object] = Field(default_factory=dict)
 class QueueNodeStartedEvent(AppQueueEvent):
@@ -364,27 +331,24 @@ class QueueNodeStartedEvent(AppQueueEvent):
     node_execution_id: str
     node_id: str
+    node_title: str
     node_type: NodeType
-    node_data: BaseNodeData
-    node_run_index: int = 1
+    node_run_index: int = 1  # FIXME(-LAN-): may not used
     predecessor_node_id: str | None = None
     parallel_id: str | None = None
-    """parallel id if node is in parallel"""
     parallel_start_node_id: str | None = None
-    """parallel start node id if node is in parallel"""
     parent_parallel_id: str | None = None
-    """parent parallel id if node is in parallel"""
     parent_parallel_start_node_id: str | None = None
-    """parent parallel start node id if node is in parallel"""
     in_iteration_id: str | None = None
-    """iteration id if node is in iteration"""
     in_loop_id: str | None = None
-    """loop id if node is in loop"""
     start_at: datetime
     parallel_mode_run_id: str | None = None
-    """iteration run in parallel mode run id"""
     agent_strategy: AgentNodeStrategyInit | None = None
+    # FIXME(-LAN-): only for ToolNode, need to refactor
+    provider_type: str  # should be a core.tools.entities.tool_entities.ToolProviderType
+    provider_id: str
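Taken together, these node-event hunks drop the full `node_data: BaseNodeData` payload in favor of a plain `node_title: str` (plus tool-provider fields on the started event). Downstream code can then depend on a narrow view instead of the whole node config; a minimal sketch, where the protocol name is my own:

from typing import Protocol


class NodeEventView(Protocol):
    """The slice of a queue event that display code needs (illustrative)."""

    node_id: str
    node_title: str  # replaces reaching into event.node_data.title


def describe(event: NodeEventView) -> str:
    # No BaseNodeData import is required on the consumer side anymore.
    return f"{event.node_id}: {event.node_title}"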
 class QueueNodeSucceededEvent(AppQueueEvent):
     """
@@ -396,7 +360,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     node_execution_id: str
     node_id: str
     node_type: NodeType
-    node_data: BaseNodeData
     parallel_id: str | None = None
     """parallel id if node is in parallel"""
     parallel_start_node_id: str | None = None
@@ -411,16 +374,12 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     """loop id if node is in loop"""
     start_at: datetime
-    inputs: Mapping[str, Any] | None = None
-    process_data: Mapping[str, Any] | None = None
-    outputs: Mapping[str, Any] | None = None
+    inputs: Mapping[str, object] = Field(default_factory=dict)
+    process_data: Mapping[str, object] = Field(default_factory=dict)
+    outputs: Mapping[str, object] = Field(default_factory=dict)
     execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
     error: str | None = None
-    """single iteration duration map"""
-    iteration_duration_map: dict[str, float] | None = None
-    """single loop duration map"""
-    loop_duration_map: dict[str, float] | None = None
 class QueueAgentLogEvent(AppQueueEvent):
@@ -436,7 +395,7 @@ class QueueAgentLogEvent(AppQueueEvent):
     error: str | None = None
     status: str
     data: Mapping[str, Any]
-    metadata: Mapping[str, Any] | None = None
+    metadata: Mapping[str, object] = Field(default_factory=dict)
     node_id: str
@@ -445,81 +404,15 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
     event: QueueEvent = QueueEvent.RETRY
-    inputs: Mapping[str, Any] | None = None
-    process_data: Mapping[str, Any] | None = None
-    outputs: Mapping[str, Any] | None = None
+    inputs: Mapping[str, object] = Field(default_factory=dict)
+    process_data: Mapping[str, object] = Field(default_factory=dict)
+    outputs: Mapping[str, object] = Field(default_factory=dict)
     execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
     error: str
     retry_index: int  # retry index
-class QueueNodeInIterationFailedEvent(AppQueueEvent):
-    """
-    QueueNodeInIterationFailedEvent entity
-    """
-
-    event: QueueEvent = QueueEvent.NODE_FAILED
-
-    node_execution_id: str
-    node_id: str
-    node_type: NodeType
-    node_data: BaseNodeData
-    parallel_id: str | None = None
-    """parallel id if node is in parallel"""
-    parallel_start_node_id: str | None = None
-    """parallel start node id if node is in parallel"""
-    parent_parallel_id: str | None = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: str | None = None
-    """parent parallel start node id if node is in parallel"""
-    in_iteration_id: str | None = None
-    """iteration id if node is in iteration"""
-    in_loop_id: str | None = None
-    """loop id if node is in loop"""
-    start_at: datetime
-
-    inputs: Mapping[str, Any] | None = None
-    process_data: Mapping[str, Any] | None = None
-    outputs: Mapping[str, Any] | None = None
-    execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
-    error: str
-
-
-class QueueNodeInLoopFailedEvent(AppQueueEvent):
-    """
-    QueueNodeInLoopFailedEvent entity
-    """
-
-    event: QueueEvent = QueueEvent.NODE_FAILED
-
-    node_execution_id: str
-    node_id: str
-    node_type: NodeType
-    node_data: BaseNodeData
-    parallel_id: str | None = None
-    """parallel id if node is in parallel"""
-    parallel_start_node_id: str | None = None
-    """parallel start node id if node is in parallel"""
-    parent_parallel_id: str | None = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: str | None = None
-    """parent parallel start node id if node is in parallel"""
-    in_iteration_id: str | None = None
-    """iteration id if node is in iteration"""
-    in_loop_id: str | None = None
-    """loop id if node is in loop"""
-    start_at: datetime
-
-    inputs: Mapping[str, Any] | None = None
-    process_data: Mapping[str, Any] | None = None
-    outputs: Mapping[str, Any] | None = None
-    execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
-    error: str
 class QueueNodeExceptionEvent(AppQueueEvent):
     """
     QueueNodeExceptionEvent entity
@@ -530,7 +423,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
     node_execution_id: str
     node_id: str
     node_type: NodeType
-    node_data: BaseNodeData
     parallel_id: str | None = None
     """parallel id if node is in parallel"""
     parallel_start_node_id: str | None = None
@@ -545,9 +437,9 @@ class QueueNodeExceptionEvent(AppQueueEvent):
     """loop id if node is in loop"""
     start_at: datetime
-    inputs: Mapping[str, Any] | None = None
-    process_data: Mapping[str, Any] | None = None
-    outputs: Mapping[str, Any] | None = None
+    inputs: Mapping[str, object] = Field(default_factory=dict)
+    process_data: Mapping[str, object] = Field(default_factory=dict)
+    outputs: Mapping[str, object] = Field(default_factory=dict)
     execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
     error: str
@@ -563,24 +455,16 @@ class QueueNodeFailedEvent(AppQueueEvent):
     node_execution_id: str
     node_id: str
     node_type: NodeType
-    node_data: BaseNodeData
     parallel_id: str | None = None
-    """parallel id if node is in parallel"""
-    parallel_start_node_id: str | None = None
-    """parallel start node id if node is in parallel"""
-    parent_parallel_id: str | None = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: str | None = None
-    """parent parallel start node id if node is in parallel"""
     in_iteration_id: str | None = None
     """iteration id if node is in iteration"""
     in_loop_id: str | None = None
     """loop id if node is in loop"""
     start_at: datetime
-    inputs: Mapping[str, Any] | None = None
-    process_data: Mapping[str, Any] | None = None
-    outputs: Mapping[str, Any] | None = None
+    inputs: Mapping[str, object] = Field(default_factory=dict)
+    process_data: Mapping[str, object] = Field(default_factory=dict)
+    outputs: Mapping[str, object] = Field(default_factory=dict)
     execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
     error: str
@@ -610,7 +494,7 @@ class QueueErrorEvent(AppQueueEvent):
     """
     event: QueueEvent = QueueEvent.ERROR
-    error: Any | None = None
+    error: Any = None

 class QueuePingEvent(AppQueueEvent):
@@ -678,61 +562,3 @@ class WorkflowQueueMessage(QueueMessage):
     """
     pass
-class QueueParallelBranchRunStartedEvent(AppQueueEvent):
-    """
-    QueueParallelBranchRunStartedEvent entity
-    """
-
-    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
-
-    parallel_id: str
-    parallel_start_node_id: str
-    parent_parallel_id: str | None = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: str | None = None
-    """parent parallel start node id if node is in parallel"""
-    in_iteration_id: str | None = None
-    """iteration id if node is in iteration"""
-    in_loop_id: str | None = None
-    """loop id if node is in loop"""
-
-
-class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
-    """
-    QueueParallelBranchRunSucceededEvent entity
-    """
-
-    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
-
-    parallel_id: str
-    parallel_start_node_id: str
-    parent_parallel_id: str | None = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: str | None = None
-    """parent parallel start node id if node is in parallel"""
-    in_iteration_id: str | None = None
-    """iteration id if node is in iteration"""
-    in_loop_id: str | None = None
-    """loop id if node is in loop"""
-
-
-class QueueParallelBranchRunFailedEvent(AppQueueEvent):
-    """
-    QueueParallelBranchRunFailedEvent entity
-    """
-
-    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
-
-    parallel_id: str
-    parallel_start_node_id: str
-    parent_parallel_id: str | None = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: str | None = None
-    """parent parallel start node id if node is in parallel"""
-    in_iteration_id: str | None = None
-    """iteration id if node is in iteration"""
-    in_loop_id: str | None = None
-    """loop id if node is in loop"""
-    error: str

View File

@@ -0,0 +1,14 @@
+from typing import Any
+
+from pydantic import BaseModel
+
+
+class RagPipelineInvokeEntity(BaseModel):
+    pipeline_id: str
+    application_generate_entity: dict[str, Any]
+    user_id: str
+    tenant_id: str
+    workflow_id: str
+    streaming: bool
+    workflow_execution_id: str | None = None
+    workflow_thread_pool_id: str | None = None
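A quick construction sketch for the new entity, assuming pydantic v2; every value below is made up:

entity = RagPipelineInvokeEntity(
    pipeline_id="pl-123",
    application_generate_entity={"inputs": {}, "query": "hello"},
    user_id="user-1",
    tenant_id="tenant-1",
    workflow_id="wf-1",
    streaming=True,
)
assert entity.workflow_execution_id is None  # optional fields default to None
payload = entity.model_dump_json()           # round-trips through JSON for queueing
restored = RagPipelineInvokeEntity.model_validate_json(payload)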

View File

@@ -1,13 +1,13 @@
 from collections.abc import Mapping, Sequence
-from enum import StrEnum, auto
+from enum import StrEnum
 from typing import Any

 from pydantic import BaseModel, ConfigDict, Field

 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from core.workflow.entities.node_entities import AgentNodeStrategyInit
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from core.workflow.entities import AgentNodeStrategyInit
+from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 class AnnotationReplyAccount(BaseModel):
@@ -55,32 +55,30 @@ class StreamEvent(StrEnum):
     Stream event
     """

-    PING = auto()
-    ERROR = auto()
-    MESSAGE = auto()
-    MESSAGE_END = auto()
-    TTS_MESSAGE = auto()
-    TTS_MESSAGE_END = auto()
-    MESSAGE_FILE = auto()
-    MESSAGE_REPLACE = auto()
-    AGENT_THOUGHT = auto()
-    AGENT_MESSAGE = auto()
-    WORKFLOW_STARTED = auto()
-    WORKFLOW_FINISHED = auto()
-    NODE_STARTED = auto()
-    NODE_FINISHED = auto()
-    NODE_RETRY = auto()
-    PARALLEL_BRANCH_STARTED = auto()
-    PARALLEL_BRANCH_FINISHED = auto()
-    ITERATION_STARTED = auto()
-    ITERATION_NEXT = auto()
-    ITERATION_COMPLETED = auto()
-    LOOP_STARTED = auto()
-    LOOP_NEXT = auto()
-    LOOP_COMPLETED = auto()
-    TEXT_CHUNK = auto()
-    TEXT_REPLACE = auto()
-    AGENT_LOG = auto()
+    PING = "ping"
+    ERROR = "error"
+    MESSAGE = "message"
+    MESSAGE_END = "message_end"
+    TTS_MESSAGE = "tts_message"
+    TTS_MESSAGE_END = "tts_message_end"
+    MESSAGE_FILE = "message_file"
+    MESSAGE_REPLACE = "message_replace"
+    AGENT_THOUGHT = "agent_thought"
+    AGENT_MESSAGE = "agent_message"
+    WORKFLOW_STARTED = "workflow_started"
+    WORKFLOW_FINISHED = "workflow_finished"
+    NODE_STARTED = "node_started"
+    NODE_FINISHED = "node_finished"
+    NODE_RETRY = "node_retry"
+    ITERATION_STARTED = "iteration_started"
+    ITERATION_NEXT = "iteration_next"
+    ITERATION_COMPLETED = "iteration_completed"
+    LOOP_STARTED = "loop_started"
+    LOOP_NEXT = "loop_next"
+    LOOP_COMPLETED = "loop_completed"
+    TEXT_CHUNK = "text_chunk"
+    TEXT_REPLACE = "text_replace"
+    AGENT_LOG = "agent_log"
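Note that for a `StrEnum`, `auto()` already yields the lowercased member name, so the renamed members keep the same wire values; spelling the strings out pins the protocol independently of future member renames. A standalone check (Python 3.11+):

from enum import StrEnum, auto


class Implicit(StrEnum):
    MESSAGE_END = auto()  # value derived from the member name


class Explicit(StrEnum):
    MESSAGE_END = "message_end"  # value pinned independently of the name


assert Implicit.MESSAGE_END == "message_end"
assert Explicit.MESSAGE_END == "message_end"
# With explicit values, renaming a member no longer silently changes
# what clients see on the SSE stream.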
 class StreamResponse(BaseModel):
@@ -138,7 +136,7 @@ class MessageEndStreamResponse(StreamResponse):
     event: StreamEvent = StreamEvent.MESSAGE_END
     id: str
-    metadata: dict = Field(default_factory=dict)
+    metadata: Mapping[str, object] = Field(default_factory=dict)
     files: Sequence[Mapping[str, Any]] | None = None
@@ -175,7 +173,7 @@ class AgentThoughtStreamResponse(StreamResponse):
     thought: str | None = None
     observation: str | None = None
     tool: str | None = None
-    tool_labels: dict | None = None
+    tool_labels: Mapping[str, object] = Field(default_factory=dict)
     tool_input: str | None = None
     message_files: list[str] | None = None
@@ -228,7 +226,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
         elapsed_time: float
         total_tokens: int
         total_steps: int
-        created_by: dict | None = None
+        created_by: Mapping[str, object] = Field(default_factory=dict)
         created_at: int
         finished_at: int
         exceptions_count: int | None = 0
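Several response models here also move from bare `dict` annotations to `Mapping[str, object]`. Besides gaining the empty default, `Mapping` advertises the payload as read-only to type checkers while pydantic still accepts plain dicts as input; a sketch with an illustrative model name:

from collections.abc import Mapping

from pydantic import BaseModel, Field


class Response(BaseModel):
    metadata: Mapping[str, object] = Field(default_factory=dict)


r = Response(metadata={"usage": {"total_tokens": 42}})
# A static checker rejects writes through the Mapping interface:
# r.metadata["usage"] = {}   # error: Mapping has no __setitem__
print(r.metadata["usage"])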
@@ -256,8 +254,9 @@ class NodeStartStreamResponse(StreamResponse):
         index: int
         predecessor_node_id: str | None = None
         inputs: Mapping[str, Any] | None = None
+        inputs_truncated: bool = False
         created_at: int
-        extras: dict = Field(default_factory=dict)
+        extras: dict[str, object] = Field(default_factory=dict)
         parallel_id: str | None = None
         parallel_start_node_id: str | None = None
         parent_parallel_id: str | None = None
@@ -313,8 +312,11 @@ class NodeFinishStreamResponse(StreamResponse):
         index: int
         predecessor_node_id: str | None = None
         inputs: Mapping[str, Any] | None = None
+        inputs_truncated: bool = False
         process_data: Mapping[str, Any] | None = None
+        process_data_truncated: bool = False
         outputs: Mapping[str, Any] | None = None
+        outputs_truncated: bool = False
         status: str
         error: str | None = None
         elapsed_time: float
@@ -382,8 +384,11 @@ class NodeRetryStreamResponse(StreamResponse):
         index: int
         predecessor_node_id: str | None = None
         inputs: Mapping[str, Any] | None = None
+        inputs_truncated: bool = False
         process_data: Mapping[str, Any] | None = None
+        process_data_truncated: bool = False
         outputs: Mapping[str, Any] | None = None
+        outputs_truncated: bool = False
         status: str
         error: str | None = None
         elapsed_time: float
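The new `*_truncated` flags let the server ship a shortened preview of a large `inputs`/`process_data`/`outputs` payload while telling the client the full value was cut. A minimal sketch of how a producer might set such a pair; the byte limit and helper below are assumptions, not code from this diff:

import json
from collections.abc import Mapping
from typing import Any

MAX_BYTES = 64_000  # assumed limit, for illustration only


def preview(value: Mapping[str, Any] | None) -> tuple[Mapping[str, Any] | None, bool]:
    """Return the value to stream plus a flag saying whether it was truncated."""
    if value is None:
        return None, False
    if len(json.dumps(dict(value), default=str)) <= MAX_BYTES:
        return value, False
    # Keep a small marker instead of the oversized payload.
    return {"__truncated__": True}, True


outputs, outputs_truncated = preview({"text": "x" * 100_000})
assert outputs_truncated is True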
@@ -436,54 +441,6 @@ class NodeRetryStreamResponse(StreamResponse):
         }
-
-
-class ParallelBranchStartStreamResponse(StreamResponse):
-    """
-    ParallelBranchStartStreamResponse entity
-    """
-
-    class Data(BaseModel):
-        """
-        Data entity
-        """
-
-        parallel_id: str
-        parallel_branch_id: str
-        parent_parallel_id: str | None = None
-        parent_parallel_start_node_id: str | None = None
-        iteration_id: str | None = None
-        loop_id: str | None = None
-        created_at: int
-
-    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
-    workflow_run_id: str
-    data: Data
-
-
-class ParallelBranchFinishedStreamResponse(StreamResponse):
-    """
-    ParallelBranchFinishedStreamResponse entity
-    """
-
-    class Data(BaseModel):
-        """
-        Data entity
-        """
-
-        parallel_id: str
-        parallel_branch_id: str
-        parent_parallel_id: str | None = None
-        parent_parallel_start_node_id: str | None = None
-        iteration_id: str | None = None
-        loop_id: str | None = None
-        status: str
-        error: str | None = None
-        created_at: int
-
-    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
-    workflow_run_id: str
-    data: Data
 class IterationNodeStartStreamResponse(StreamResponse):
     """
     NodeStartStreamResponse entity
@@ -502,8 +459,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
         extras: dict = Field(default_factory=dict)
         metadata: Mapping = {}
         inputs: Mapping = {}
-        parallel_id: str | None = None
-        parallel_start_node_id: str | None = None
+        inputs_truncated: bool = False

     event: StreamEvent = StreamEvent.ITERATION_STARTED
     workflow_run_id: str
@@ -526,12 +482,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
         title: str
         index: int
         created_at: int
-        pre_iteration_output: Any | None = None
         extras: dict = Field(default_factory=dict)
-        parallel_id: str | None = None
-        parallel_start_node_id: str | None = None
-        parallel_mode_run_id: str | None = None
-        duration: float | None = None

     event: StreamEvent = StreamEvent.ITERATION_NEXT
     workflow_run_id: str
@@ -553,18 +504,18 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
         node_type: str
         title: str
         outputs: Mapping | None = None
+        outputs_truncated: bool = False
         created_at: int
         extras: dict | None = None
         inputs: Mapping | None = None
+        inputs_truncated: bool = False
         status: WorkflowNodeExecutionStatus
         error: str | None = None
         elapsed_time: float
         total_tokens: int
-        execution_metadata: Mapping | None = None
+        execution_metadata: Mapping[str, object] = Field(default_factory=dict)
         finished_at: int
         steps: int
-        parallel_id: str | None = None
-        parallel_start_node_id: str | None = None

     event: StreamEvent = StreamEvent.ITERATION_COMPLETED
     workflow_run_id: str
@@ -589,6 +540,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
         extras: dict = Field(default_factory=dict)
         metadata: Mapping = {}
         inputs: Mapping = {}
+        inputs_truncated: bool = False
         parallel_id: str | None = None
         parallel_start_node_id: str | None = None
@@ -613,12 +565,11 @@ class LoopNodeNextStreamResponse(StreamResponse):
         title: str
         index: int
         created_at: int
-        pre_loop_output: Any | None = None
-        extras: dict = Field(default_factory=dict)
+        pre_loop_output: Any = None
+        extras: Mapping[str, object] = Field(default_factory=dict)
         parallel_id: str | None = None
         parallel_start_node_id: str | None = None
         parallel_mode_run_id: str | None = None
-        duration: float | None = None

     event: StreamEvent = StreamEvent.LOOP_NEXT
     workflow_run_id: str
@@ -640,14 +591,16 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
         node_type: str
         title: str
         outputs: Mapping | None = None
+        outputs_truncated: bool = False
         created_at: int
         extras: dict | None = None
         inputs: Mapping | None = None
+        inputs_truncated: bool = False
         status: WorkflowNodeExecutionStatus
         error: str | None = None
         elapsed_time: float
         total_tokens: int
-        execution_metadata: Mapping | None = None
+        execution_metadata: Mapping[str, object] = Field(default_factory=dict)
         finished_at: int
         steps: int
         parallel_id: str | None = None
@@ -757,7 +710,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
         conversation_id: str
         message_id: str
         answer: str
-        metadata: dict = Field(default_factory=dict)
+        metadata: Mapping[str, object] = Field(default_factory=dict)
         created_at: int

     data: Data
@@ -777,7 +730,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
         mode: str
         message_id: str
         answer: str
-        metadata: dict = Field(default_factory=dict)
+        metadata: Mapping[str, object] = Field(default_factory=dict)
         created_at: int

     data: Data
@@ -825,7 +778,7 @@ class AgentLogStreamResponse(StreamResponse):
         error: str | None = None
         status: str
         data: Mapping[str, Any]
-        metadata: Mapping[str, Any] | None = None
+        metadata: Mapping[str, object] = Field(default_factory=dict)
         node_id: str

     event: StreamEvent = StreamEvent.AGENT_LOG

View File

@@ -138,6 +138,8 @@ class MessageCycleManager:
         :param event: event
         :return:
         """
+        if not self._application_generate_entity.app_config.additional_features:
+            raise ValueError("Additional features not found")
         if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
            self._task_state.metadata.retriever_resources = event.retriever_resources
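The two added lines guard an apparently optional attribute before it is dereferenced; without them, `additional_features.show_retrieve_source` can raise `AttributeError` on `None` and fails static type checks. The same narrowing pattern in isolation, with hypothetical names:

from dataclasses import dataclass


@dataclass
class Features:
    show_retrieve_source: bool = False


@dataclass
class Config:
    additional_features: Features | None = None


def handle(config: Config) -> bool:
    if not config.additional_features:
        raise ValueError("Additional features not found")
    # After the guard, type checkers narrow the field to Features.
    return config.additional_features.show_retrieve_source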

View File

@@ -109,7 +109,9 @@ class AppGeneratorTTSPublisher:
             elif isinstance(message.event, QueueNodeSucceededEvent):
                 if message.event.outputs is None:
                     continue
-                self.msg_text += message.event.outputs.get("output", "")
+                output = message.event.outputs.get("output", "")
+                if isinstance(output, str):
+                    self.msg_text += output
                 self.last_message = message
             sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
             if len(sentence_arr) >= min(self.max_sentence, 7):
@@ -119,7 +121,7 @@ class AppGeneratorTTSPublisher:
                     _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
                 )
                 future_queue.put(futures_result)
-                if text_tmp:
+                if isinstance(text_tmp, str):
                     self.msg_text = text_tmp
                 else:
                     self.msg_text = ""
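Both TTS hunks make the same move: with the loosely typed outputs mapping, `.get("output", "")` can return any object, so an `isinstance` check keeps non-string node outputs (files, lists) out of the speech buffer. A reduced sketch of the pattern:

def safe_append(buffer: str, outputs: dict[str, object]) -> str:
    output = outputs.get("output", "")
    # Only plain text is speakable; skip structured outputs.
    if isinstance(output, str):
        return buffer + output
    return buffer


assert safe_append("Hi ", {"output": "there"}) == "Hi there"
assert safe_append("Hi ", {"output": ["not", "text"]}) == "Hi "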

View File

@@ -105,6 +105,14 @@ class DifyAgentCallbackHandler(BaseModel):
         self.current_loop += 1

+    def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None:
+        """Run on datasource start."""
+        if dify_config.DEBUG:
+            print_text(
+                "\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + str(datasource_inputs) + "\n",
+                color=self.color,
+            )
+
     @property
     def ignore_agent(self) -> bool:
         """Whether to ignore agent callbacks."""

View File

@@ -0,0 +1,41 @@
+from abc import ABC, abstractmethod
+
+from configs import dify_config
+from core.datasource.__base.datasource_runtime import DatasourceRuntime
+from core.datasource.entities.datasource_entities import (
+    DatasourceEntity,
+    DatasourceProviderType,
+)
+
+
+class DatasourcePlugin(ABC):
+    entity: DatasourceEntity
+    runtime: DatasourceRuntime
+    icon: str
+
+    def __init__(
+        self,
+        entity: DatasourceEntity,
+        runtime: DatasourceRuntime,
+        icon: str,
+    ) -> None:
+        self.entity = entity
+        self.runtime = runtime
+        self.icon = icon
+
+    @abstractmethod
+    def datasource_provider_type(self) -> str:
+        """
+        returns the type of the datasource provider
+        """
+        return DatasourceProviderType.LOCAL_FILE
+
+    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
+        return self.__class__(
+            entity=self.entity.model_copy(),
+            runtime=runtime,
+            icon=self.icon,
+        )
+
+    def get_icon_url(self, tenant_id: str) -> str:
+        return f"{dify_config.CONSOLE_API_URL}/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={self.icon}"  # noqa: E501
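`datasource_provider_type` is abstract, so concrete plugins must override it, while `fork_datasource_runtime` re-instantiates the same class with a copied entity and a new runtime. A hedged subclass sketch; the class name and the commented call sites are illustrative, not part of this diff:

class LocalFileDatasourcePlugin(DatasourcePlugin):
    def datasource_provider_type(self) -> str:
        return DatasourceProviderType.LOCAL_FILE


# plugin = LocalFileDatasourcePlugin(entity=entity, runtime=runtime, icon="icon.svg")
# forked = plugin.fork_datasource_runtime(runtime=other_runtime)
# The fork shares no mutable state: the entity is model_copy()'d,
# and the runtime is replaced wholesale.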

Some files were not shown because too many files have changed in this diff.