Compare commits

..

42 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
85396a2af2 Initial plan 2026-03-09 08:05:17 +00:00
Dev Sharma
6c19e75969 test: improve unit tests for controllers.web (#32150)
Co-authored-by: Rajat Agarwal <rajat.agarwal@infocusp.com>
2026-03-09 15:58:34 +08:00
wangxiaolei
9970f4449a refactor: reuse redis connection instead of create new one (#32678)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-09 15:53:21 +08:00
Sense_wang
cbb19cce39 docs: use docker compose command consistently in README (#33077)
Co-authored-by: Contributor <contributor@example.com>
2026-03-09 15:02:30 +08:00
hj24
0aef09d630 feat: support relative mode for message clean command (#32834) 2026-03-09 14:32:35 +08:00
wangxiaolei
d2208ad43e fix: fix allow handle value is none (#33031) 2026-03-09 14:20:44 +08:00
非法操作
4a2ba058bb feat: when copy/paste multi nodes not require reconnect them (#32631) 2026-03-09 13:55:12 +08:00
非法操作
654e41d47f fix: workflow_as_tool not work with json input (#32554) 2026-03-09 13:54:54 +08:00
非法操作
ec5409756e feat: keep connections when change node (#31982) 2026-03-09 13:54:10 +08:00
Olexandr88
8b1ea3a8f5 refactor: deduplicate legacy section mapping in ConfigHelper (#32715) 2026-03-09 13:43:06 +08:00
yyh
f2d3feca66 fix(web): fix tool item text not vertically centered in block selector (#33148) 2026-03-09 13:38:11 +08:00
yyh
0590b09958 feat(web): add context menu primitive and dropdown link item (#33125) 2026-03-09 12:05:38 +08:00
wangxiaolei
66f9fde2fe fix: fix metadata filter condition not extract from {{}} (#33141)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-09 11:51:08 +08:00
Stephen Zhou
1811a855ab chore: update vinext, agentation, remove Prism in lexical (#33142) 2026-03-09 11:40:04 +08:00
Jiaquan Yi
322cd37de1 fix: handle backslash path separators in DOCX ZIP entries exported on…(#33129) (#33131)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-09 10:49:42 +08:00
rajatagarwal-oss
2cc0de9c1b test: unit test case for controllers.common module (#32056) 2026-03-09 09:45:42 +08:00
wangxiaolei
46098b2be6 refactor: use thread.Timer instead of time.sleep (#33121) 2026-03-09 09:38:16 +08:00
akashseth-ifp
7dcf94f48f test: remaining header component and increase branch coverage (#33052)
Co-authored-by: sahil <sahil@infocusp.com>
2026-03-09 09:18:11 +08:00
yyh
7869551afd fix(web): stabilize dayjs timezone tests against DST transitions (#33134) 2026-03-09 09:16:45 +08:00
CoralGarden52
c925d17e8f chore: add TypedDict related prompt to api/AGENTS.md (#33116)
Some checks failed
Mark stale issues and pull requests / stale (push) Has been cancelled
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
2026-03-08 07:03:52 +09:00
Angel
dc2a53d834 feat: add files to message end pr32019 (#32242)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Co-authored-by: fatelei <fatelei@gmail.com>
Co-authored-by: angel.k <angel.kolev@solaredge.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-07 20:01:12 +08:00
hj24
05ab107e73 feat: add export app messages (#32990)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
2026-03-07 11:27:15 +08:00
pepsi
c016793efb refactor: pass KnowledgeConfiguration directly instead of dict (#32732)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Co-authored-by: pepsi <pepsi@pepsidexuniji.local>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-06 22:15:32 +09:00
Coding On Star
a5bcbaebb7 feat(toast): add IToastProps type import to enhance type safety (#33096)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-03-06 19:22:55 +08:00
Saumya Talwani
f50e44b24a test: improve coverage for some test files (#32916)
Signed-off-by: edvatar <88481784+toroleapinc@users.noreply.github.com>
Signed-off-by: -LAN- <laipz8200@outlook.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: majiayu000 <1835304752@qq.com>
Co-authored-by: Poojan <poojan@infocusp.com>
Co-authored-by: sahil-infocusp <73810410+sahil-infocusp@users.noreply.github.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: Pandaaaa906 <ye.pandaaaa906@gmail.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: heyszt <270985384@qq.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Ijas <ijas.ahmd.ap@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: 木之本澪 <kinomotomiovo@gmail.com>
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
Co-authored-by: 不做了睡大觉 <64798754+stakeswky@users.noreply.github.com>
Co-authored-by: User <user@example.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: edvatar <88481784+toroleapinc@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Leilei <138381132+Inlei@users.noreply.github.com>
Co-authored-by: HaKu <104669497+haku-ink@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
Co-authored-by: Varun Chawla <34209028+veeceey@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: tda <95275462+tda1017@users.noreply.github.com>
Co-authored-by: root <root@DESKTOP-KQLO90N>
Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
Co-authored-by: Niels Kaspers <153818647+nielskaspers@users.noreply.github.com>
Co-authored-by: hj24 <mambahj24@gmail.com>
Co-authored-by: Tyson Cung <45380903+tysoncung@users.noreply.github.com>
Co-authored-by: Stephen Zhou <hi@hyoban.cc>
Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com>
Co-authored-by: slegarraga <64795732+slegarraga@users.noreply.github.com>
Co-authored-by: 99 <wh2099@pm.me>
Co-authored-by: Br1an <932039080@qq.com>
Co-authored-by: L1nSn0w <l1nsn0w@qq.com>
Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai>
Co-authored-by: akkoaya <151345394+akkoaya@users.noreply.github.com>
Co-authored-by: 盐粒 Yanli <yanli@dify.ai>
Co-authored-by: lif <1835304752@qq.com>
Co-authored-by: weiguang li <codingpunk@gmail.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: HanWenbo <124024253+hwb96@users.noreply.github.com>
Co-authored-by: Coding On Star <447357187@qq.com>
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: Stable Genius <stablegenius043@gmail.com>
Co-authored-by: Stable Genius <259448942+stablegenius49@users.noreply.github.com>
Co-authored-by: ふるい <46769295+Echo0ff@users.noreply.github.com>
Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
2026-03-06 18:59:16 +08:00
Nite Knite
09347d5e8b chore: fix account dropdown test (#33093) 2026-03-06 18:19:02 +08:00
Stephen Zhou
299a893ac5 chore: bring back code-inspector-plugin and agentation (#33088)
Co-authored-by: zhsama <zhsama@users.noreply.github.com>
2026-03-06 17:01:18 +08:00
Junyan Chin
c477571553 perf: no longer record install count for auto upgrade (#33086) 2026-03-06 16:19:30 +08:00
QuantumGhost
d01acfc490 fix(api): fix the issue that workflow_runs.started_at is overwritten while resuming (#32851)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-06 15:41:30 +08:00
Stephen Zhou
f05f0be55f chore: use react-grab to replace code-inspector-plugin (#33078)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
2026-03-06 14:54:24 +08:00
eux
e74cda6535 feat(tasks): isolate summary generation to dedicated dataset_summary queue (#32972) 2026-03-06 14:35:28 +08:00
Nite Knite
0490756ab2 chore: add support email env (#33075) 2026-03-06 14:29:29 +08:00
非法操作
dc31b07533 fix(type-check): resolve missing-attribute in app dataset join update handler (#33071) 2026-03-06 11:45:51 +08:00
Copilot
d1eaa41dd1 fix(i18n): correct French translation of "disabled" from medical term to UI-appropriate term (#33067)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-03-06 09:57:43 +08:00
非法操作
7ffa6c1849 fix: conversation var unexpected reset after HITL node (#32936) 2026-03-06 09:57:09 +08:00
kurokobo
ad81513b6a fix: show citations in advanced chat apps (#32985)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-06 09:56:14 +08:00
Lovish Arora
f751864ab3 fix(api): return inserted ids from Chroma and Clickzetta add_texts (#33065)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-06 09:49:53 +08:00
盐粒 Yanli
49dcf5e0d9 chore: add local pyrefly exclude workflow (#33059) 2026-03-06 09:49:23 +08:00
statxc
741d48560d refactor(api): add TypedDict definitions to models/model.py (#32925)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-06 08:42:54 +09:00
dependabot[bot]
6bd1be9e16 chore(deps): bump markdown from 3.5.2 to 3.8.1 in /api (#33064)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-06 07:41:55 +09:00
木之本澪
f76de73be4 test: migrate dataset permission service SQL tests to testcontainers (#32546)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-06 07:21:25 +09:00
dependabot[bot]
98ba091a50 chore(deps): bump dompurify from 3.3.0 to 3.3.2 in /web (#33062)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-06 06:48:59 +09:00
274 changed files with 27728 additions and 3634 deletions

View File

@@ -7,7 +7,7 @@ cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc

View File

@@ -25,6 +25,10 @@ updates:
interval: "weekly"
open-pull-requests-limit: 2
groups:
lexical:
patterns:
- "lexical"
- "@lexical/*"
storybook:
patterns:
- "storybook"
@@ -33,5 +37,7 @@ updates:
patterns:
- "*"
exclude-patterns:
- "lexical"
- "@lexical/*"
- "storybook"
- "@storybook/*"

View File

@@ -37,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
"dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
"--loglevel",
"INFO"
],

View File

@@ -68,8 +68,9 @@ lint:
@echo "✅ Linting complete"
type-check:
@echo "📝 Running type checks (basedpyright + mypy)..."
@echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@./dev/pyrefly-check-local
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@echo "✅ Type checks complete"
@@ -131,7 +132,7 @@ help:
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checks (basedpyright, mypy)"
@echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@echo ""
@echo "Docker Build Targets:"

View File

@@ -133,7 +133,7 @@ Star Dify on GitHub and be instantly notified of new releases.
### Custom configurations
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
#### Customizing Suggested Questions

View File

@@ -62,6 +62,22 @@ This is the default standard for backend code in this repo. Follow it for new co
- Code should usually include type annotations that match the repo's current Python version (avoid untyped public APIs and "mystery" values).
- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there's a strong reason.
- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`.
- For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional).
- Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown.
```python
from datetime import datetime
from typing import NotRequired, TypedDict
class UserProfile(TypedDict):
user_id: str
email: str
created_at: datetime
nickname: NotRequired[str]
```
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
```python

View File

@@ -30,6 +30,7 @@ from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from extensions.storage.opendal_storage import OpenDALStorage
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from libs.db_migration_lock import DbMigrationAutoRenewLock
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
@@ -2598,15 +2599,29 @@ def migrate_oss(
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
required=False,
default=None,
help="Lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
required=False,
default=None,
help="Upper bound (exclusive) for created_at.",
)
@click.option(
"--from-days-ago",
type=int,
default=None,
help="Relative lower bound in days ago (inclusive). Must be used with --before-days.",
)
@click.option(
"--before-days",
type=int,
default=None,
help="Relative upper bound in days ago (exclusive). Required for relative mode.",
)
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
@click.option(
"--graceful-period",
@@ -2618,8 +2633,10 @@ def migrate_oss(
def clean_expired_messages(
batch_size: int,
graceful_period: int,
start_from: datetime.datetime,
end_before: datetime.datetime,
start_from: datetime.datetime | None,
end_before: datetime.datetime | None,
from_days_ago: int | None,
before_days: int | None,
dry_run: bool,
):
"""
@@ -2630,18 +2647,70 @@ def clean_expired_messages(
start_at = time.perf_counter()
try:
abs_mode = start_from is not None and end_before is not None
rel_mode = before_days is not None
if abs_mode and rel_mode:
raise click.UsageError(
"Options are mutually exclusive: use either (--start-from,--end-before) "
"or (--from-days-ago,--before-days)."
)
if from_days_ago is not None and before_days is None:
raise click.UsageError("--from-days-ago must be used together with --before-days.")
if (start_from is None) ^ (end_before is None):
raise click.UsageError("Both --start-from and --end-before are required when using absolute time range.")
if not abs_mode and not rel_mode:
raise click.UsageError(
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])."
)
if rel_mode:
assert before_days is not None
if before_days < 0:
raise click.UsageError("--before-days must be >= 0.")
if from_days_ago is not None:
if from_days_ago < 0:
raise click.UsageError("--from-days-ago must be >= 0.")
if from_days_ago <= before_days:
raise click.UsageError("--from-days-ago must be greater than --before-days.")
# Create policy based on billing configuration
# NOTE: graceful_period will be ignored when billing is disabled.
policy = create_message_clean_policy(graceful_period_days=graceful_period)
# Create and run the cleanup service
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
if abs_mode:
assert start_from is not None
assert end_before is not None
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
elif from_days_ago is None:
assert before_days is not None
service = MessagesCleanService.from_days(
policy=policy,
days=before_days,
batch_size=batch_size,
dry_run=dry_run,
)
else:
assert before_days is not None
assert from_days_ago is not None
now = naive_utc_now()
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=now - datetime.timedelta(days=from_days_ago),
end_before=now - datetime.timedelta(days=before_days),
batch_size=batch_size,
dry_run=dry_run,
)
stats = service.run()
end_at = time.perf_counter()
@@ -2668,3 +2737,77 @@ def clean_expired_messages(
raise
click.echo(click.style("messages cleanup completed.", fg="green"))
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
@click.option("--app-id", required=True, help="Application ID to export messages for.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Upper bound (exclusive) for created_at.",
)
@click.option(
"--filename",
required=True,
help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
)
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
def export_app_messages(
app_id: str,
start_from: datetime.datetime | None,
end_before: datetime.datetime,
filename: str,
use_cloud_storage: bool,
batch_size: int,
dry_run: bool,
):
if start_from and start_from >= end_before:
raise click.UsageError("--start-from must be before --end-before.")
from services.retention.conversation.message_export_service import AppMessageExportService
try:
validated_filename = AppMessageExportService.validate_export_filename(filename)
except ValueError as e:
raise click.BadParameter(str(e), param_hint="--filename") from e
click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
start_at = time.perf_counter()
try:
service = AppMessageExportService(
app_id=app_id,
end_before=end_before,
filename=validated_filename,
start_from=start_from,
batch_size=batch_size,
use_cloud_storage=use_cloud_storage,
dry_run=dry_run,
)
stats = service.run()
elapsed = time.perf_counter() - start_at
click.echo(
click.style(
f"export_app_messages: completed in {elapsed:.2f}s\n"
f" - Batches: {stats.batches}\n"
f" - Total messages: {stats.total_messages}\n"
f" - Messages with feedback: {stats.messages_with_feedback}\n"
f" - Total feedbacks: {stats.total_feedbacks}",
fg="green",
)
)
except Exception as e:
elapsed = time.perf_counter() - start_at
logger.exception("export_app_messages failed")
click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
raise

View File

@@ -1,3 +1,5 @@
from typing import Any, cast
from controllers.common import fields
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError
@@ -23,14 +25,14 @@ class AppParameterApi(InstalledAppResource):
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
features_dict: dict[str, Any] = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
user_input_form = features_dict.get("user_input_form", [])

View File

@@ -1,3 +1,5 @@
from typing import Any, cast
from flask_restx import Resource
from controllers.common.fields import Parameters
@@ -33,14 +35,14 @@ class AppParameterApi(Resource):
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
features_dict: dict[str, Any] = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
user_input_form = features_dict.get("user_input_form", [])

View File

@@ -1,4 +1,5 @@
import logging
from typing import Any, cast
from flask import request
from flask_restx import Resource
@@ -57,14 +58,14 @@ class AppParameterApi(WebApiResource):
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
features_dict: dict[str, Any] = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
user_input_form = features_dict.get("user_input_form", [])

View File

@@ -239,7 +239,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotCompletionAppError()
raise NotChatAppError()
message_id = str(message_id)

View File

@@ -1,10 +1,13 @@
from collections.abc import Mapping
from typing import Any
from core.app.app_config.entities import SensitiveWordAvoidanceEntity
from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager:
@classmethod
def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None:
def convert(cls, config: Mapping[str, Any]) -> SensitiveWordAvoidanceEntity | None:
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
if not sensitive_word_avoidance_dict:
return None
@@ -12,7 +15,7 @@ class SensitiveWordAvoidanceConfigManager:
if sensitive_word_avoidance_dict.get("enabled"):
return SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config"),
config=sensitive_word_avoidance_dict.get("config", {}),
)
else:
return None

View File

@@ -1,10 +1,13 @@
from typing import Any, cast
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
from models.model import AppModelConfigDict
class AgentConfigManager:
@classmethod
def convert(cls, config: dict) -> AgentEntity | None:
def convert(cls, config: AppModelConfigDict) -> AgentEntity | None:
"""
Convert model config to model config
@@ -28,17 +31,17 @@ class AgentConfigManager:
agent_tools = []
for tool in agent_dict.get("tools", []):
keys = tool.keys()
if len(keys) >= 4:
if "enabled" not in tool or not tool["enabled"]:
tool_dict = cast(dict[str, Any], tool)
if len(tool_dict) >= 4:
if "enabled" not in tool_dict or not tool_dict["enabled"]:
continue
agent_tool_properties = {
"provider_type": tool["provider_type"],
"provider_id": tool["provider_id"],
"tool_name": tool["tool_name"],
"tool_parameters": tool.get("tool_parameters", {}),
"credential_id": tool.get("credential_id", None),
"provider_type": tool_dict["provider_type"],
"provider_id": tool_dict["provider_id"],
"tool_name": tool_dict["tool_name"],
"tool_parameters": tool_dict.get("tool_parameters", {}),
"credential_id": tool_dict.get("credential_id", None),
}
agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
@@ -47,7 +50,8 @@ class AgentConfigManager:
"react_router",
"router",
}:
agent_prompt = agent_dict.get("prompt", None) or {}
agent_prompt_raw = agent_dict.get("prompt", None)
agent_prompt: dict[str, Any] = agent_prompt_raw if isinstance(agent_prompt_raw, dict) else {}
# check model mode
model_mode = config.get("model", {}).get("mode", "completion")
if model_mode == "completion":
@@ -75,7 +79,7 @@ class AgentConfigManager:
strategy=strategy,
prompt=agent_prompt_entity,
tools=agent_tools,
max_iteration=agent_dict.get("max_iteration", 10),
max_iteration=cast(int, agent_dict.get("max_iteration", 10)),
)
return None

View File

@@ -1,5 +1,5 @@
import uuid
from typing import Literal, cast
from typing import Any, Literal, cast
from core.app.app_config.entities import (
DatasetEntity,
@@ -8,13 +8,13 @@ from core.app.app_config.entities import (
ModelConfig,
)
from core.entities.agent_entities import PlanningStrategy
from models.model import AppMode
from models.model import AppMode, AppModelConfigDict
from services.dataset_service import DatasetService
class DatasetConfigManager:
@classmethod
def convert(cls, config: dict) -> DatasetEntity | None:
def convert(cls, config: AppModelConfigDict) -> DatasetEntity | None:
"""
Convert model config to model config
@@ -25,11 +25,15 @@ class DatasetConfigManager:
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
for dataset in datasets.get("datasets", []):
if not isinstance(dataset, dict):
continue
keys = list(dataset.keys())
if len(keys) == 0 or keys[0] != "dataset":
continue
dataset = dataset["dataset"]
if not isinstance(dataset, dict):
continue
if "enabled" not in dataset or not dataset["enabled"]:
continue
@@ -47,15 +51,14 @@ class DatasetConfigManager:
agent_dict = config.get("agent_mode", {})
for tool in agent_dict.get("tools", []):
keys = tool.keys()
if len(keys) == 1:
if len(tool) == 1:
# old standard
key = list(tool.keys())[0]
if key != "dataset":
continue
tool_item = tool[key]
tool_item = cast(dict[str, Any], tool)[key]
if "enabled" not in tool_item or not tool_item["enabled"]:
continue

View File

@@ -5,12 +5,13 @@ from core.app.app_config.entities import ModelConfigEntity
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from models.model import AppModelConfigDict
from models.provider_ids import ModelProviderID
class ModelConfigManager:
@classmethod
def convert(cls, config: dict) -> ModelConfigEntity:
def convert(cls, config: AppModelConfigDict) -> ModelConfigEntity:
"""
Convert model config to model config
@@ -22,7 +23,7 @@ class ModelConfigManager:
if not model_config:
raise ValueError("model is required")
completion_params = model_config.get("completion_params")
completion_params = model_config.get("completion_params") or {}
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]

View File

@@ -1,3 +1,5 @@
from typing import Any
from core.app.app_config.entities import (
AdvancedChatMessageEntity,
AdvancedChatPromptTemplateEntity,
@@ -6,12 +8,12 @@ from core.app.app_config.entities import (
)
from core.prompt.simple_prompt_transform import ModelMode
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
from models.model import AppMode
from models.model import AppMode, AppModelConfigDict
class PromptTemplateConfigManager:
@classmethod
def convert(cls, config: dict) -> PromptTemplateEntity:
def convert(cls, config: AppModelConfigDict) -> PromptTemplateEntity:
if not config.get("prompt_type"):
raise ValueError("prompt_type is required")
@@ -40,14 +42,15 @@ class PromptTemplateConfigManager:
advanced_completion_prompt_template = None
completion_prompt_config = config.get("completion_prompt_config", {})
if completion_prompt_config:
completion_prompt_template_params = {
completion_prompt_template_params: dict[str, Any] = {
"prompt": completion_prompt_config["prompt"]["text"],
}
if "conversation_histories_role" in completion_prompt_config:
conv_role = completion_prompt_config.get("conversation_histories_role")
if conv_role:
completion_prompt_template_params["role_prefix"] = {
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
"user": conv_role["user_prefix"],
"assistant": conv_role["assistant_prefix"],
}
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(

View File

@@ -1,8 +1,10 @@
import re
from typing import cast
from core.app.app_config.entities import ExternalDataVariableEntity
from core.external_data_tool.factory import ExternalDataToolFactory
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
from models.model import AppModelConfigDict
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
[
@@ -18,7 +20,7 @@ _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
class BasicVariablesConfigManager:
@classmethod
def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
def convert(cls, config: AppModelConfigDict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
"""
Convert model config to model config
@@ -51,7 +53,9 @@ class BasicVariablesConfigManager:
external_data_variables.append(
ExternalDataVariableEntity(
variable=variable["variable"], type=variable["type"], config=variable["config"]
variable=variable["variable"],
type=variable.get("type", ""),
config=variable.get("config", {}),
)
)
elif variable_type in {
@@ -64,10 +68,10 @@ class BasicVariablesConfigManager:
variable = variables[variable_type]
variable_entities.append(
VariableEntity(
type=variable_type,
variable=variable.get("variable"),
type=cast(VariableEntityType, variable_type),
variable=variable["variable"],
description=variable.get("description") or "",
label=variable.get("label"),
label=variable["label"],
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],

View File

@@ -281,7 +281,7 @@ class EasyUIBasedAppConfig(AppConfig):
app_model_config_from: EasyUIBasedAppModelConfigFrom
app_model_config_id: str
app_model_config_dict: dict
app_model_config_dict: dict[str, Any]
model: ModelConfigEntity
prompt_template: PromptTemplateEntity
dataset: DatasetEntity | None = None

View File

@@ -516,8 +516,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
graph_runtime_state=validated_state,
)
yield from self._handle_advanced_chat_message_end_event(
QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
)
yield workflow_finish_resp
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_partial_success_event(
self,
@@ -538,6 +540,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
exceptions_count=event.exceptions_count,
)
yield from self._handle_advanced_chat_message_end_event(
QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
)
yield workflow_finish_resp
def _handle_workflow_paused_event(
@@ -854,6 +859,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
yield from self._handle_workflow_paused_event(event)
break
case QueueWorkflowSucceededEvent():
yield from self._handle_workflow_succeeded_event(event, trace_manager=trace_manager)
break
case QueueWorkflowPartialSuccessEvent():
yield from self._handle_workflow_partial_success_event(event, trace_manager=trace_manager)
break
case QueueStopEvent():
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
break

View File

@@ -20,7 +20,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.entities.agent_entities import PlanningStrategy
from models.model import App, AppMode, AppModelConfig, Conversation
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
@@ -40,7 +40,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model: App,
app_model_config: AppModelConfig,
conversation: Conversation | None = None,
override_config_dict: dict | None = None,
override_config_dict: AppModelConfigDict | None = None,
) -> AgentChatAppConfig:
"""
Convert app model config to agent chat app config
@@ -61,7 +61,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict or {}
if not override_config_dict:
raise Exception("override_config_dict is required when config_from is ARGS")
config_dict = override_config_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = AgentChatAppConfig(
@@ -70,7 +72,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_mode=app_mode,
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
app_model_config_dict=cast(dict[str, Any], config_dict),
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
@@ -86,7 +88,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> AppModelConfigDict:
"""
Validate for agent chat app model config
@@ -157,7 +159,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config
return cast(AppModelConfigDict, filtered_config)
@classmethod
def validate_agent_mode_and_set_defaults(

View File

@@ -1,3 +1,5 @@
from typing import Any, cast
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@@ -13,7 +15,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import App, AppMode, AppModelConfig, Conversation
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
class ChatAppConfig(EasyUIBasedAppConfig):
@@ -31,7 +33,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
app_model: App,
app_model_config: AppModelConfig,
conversation: Conversation | None = None,
override_config_dict: dict | None = None,
override_config_dict: AppModelConfigDict | None = None,
) -> ChatAppConfig:
"""
Convert app model config to chat app config
@@ -64,7 +66,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
app_mode=app_mode,
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
app_model_config_dict=cast(dict[str, Any], config_dict),
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
@@ -79,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict):
def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
"""
Validate for chat app model config
@@ -145,4 +147,4 @@ class ChatAppConfigManager(BaseAppConfigManager):
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config
return cast(AppModelConfigDict, filtered_config)

View File

@@ -173,8 +173,10 @@ class ChatAppRunner(AppRunner):
memory=memory,
message_id=message.id,
inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
vision_enabled=bool(
application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
.get("image", {})
.get("enabled", False)
),
)
context_files = retrieved_files or []

View File

@@ -1,3 +1,5 @@
from typing import Any, cast
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@@ -8,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import App, AppMode, AppModelConfig
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict
class CompletionAppConfig(EasyUIBasedAppConfig):
@@ -22,7 +24,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: AppModelConfigDict | None = None
) -> CompletionAppConfig:
"""
Convert app model config to completion app config
@@ -40,7 +42,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict or {}
if not override_config_dict:
raise Exception("override_config_dict is required when config_from is ARGS")
config_dict = override_config_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = CompletionAppConfig(
@@ -49,7 +53,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_mode=app_mode,
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
app_model_config_dict=cast(dict[str, Any], config_dict),
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
@@ -64,7 +68,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict):
def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
"""
Validate for completion app model config
@@ -116,4 +120,4 @@ class CompletionAppConfigManager(BaseAppConfigManager):
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config
return cast(AppModelConfigDict, filtered_config)

View File

@@ -275,7 +275,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
raise ValueError("Message app_model_config is None")
override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict["model"]
completion_params = model_dict.get("completion_params")
completion_params = model_dict.get("completion_params", {})
completion_params["temperature"] = 0.9
model_dict["completion_params"] = completion_params
override_model_config_dict["model"] = model_dict

View File

@@ -132,8 +132,10 @@ class CompletionAppRunner(AppRunner):
hit_callback=hit_callback,
message_id=message.id,
inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
vision_enabled=bool(
application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
.get("image", {})
.get("enabled", False)
),
)
context_files = retrieved_files or []

View File

@@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Union, cast
from typing import Any, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -44,14 +44,13 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.app.task_pipeline.message_file_utils import prepare_file_dict
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
@@ -219,14 +218,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
publisher = None
text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
text_to_speech_dict = cast(dict[str, Any], self._app_config.app_model_config_dict.get("text_to_speech"))
if (
text_to_speech_dict
and text_to_speech_dict.get("autoPlay") == "enabled"
and text_to_speech_dict.get("enabled")
):
publisher = AppGeneratorTTSPublisher(
tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
tenant_id, text_to_speech_dict.get("voice", ""), text_to_speech_dict.get("language", None)
)
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
@@ -460,91 +459,40 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
self._task_state.metadata.usage = self._task_state.llm_result.usage
metadata_dict = self._task_state.metadata.model_dump()
# Fetch files associated with this message
files = None
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if message_files:
# Fetch all required UploadFile objects in a single query to avoid N+1 problem
upload_file_ids = list(
dict.fromkeys(
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
)
)
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
files_list = []
for message_file in message_files:
file_dict = prepare_file_dict(message_file, upload_files_map)
files_list.append(file_dict)
files = files_list or None
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
metadata=metadata_dict,
files=files,
)
def _record_files(self):
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if not message_files:
return None
files_list = []
upload_file_ids = [
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
]
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
for message_file in message_files:
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)
url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
# Fallback: generate URL even if upload_file not found
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
# For tool files, use URL directly if it's HTTP, otherwise sign it
if message_file.url.startswith("http"):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
else:
# Extract tool file id and extension from URL
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0] # Remove query params first
# Use rsplit to correctly handle filenames with multiple dots
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
file_dict = {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}
files_list.append(file_dict)
return files_list or None
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
"""
Agent message to stream response.

View File

@@ -1,7 +1,6 @@
import hashlib
import logging
import time
from threading import Thread
from threading import Thread, Timer
from typing import Union
from flask import Flask, current_app
@@ -96,9 +95,9 @@ class MessageCycleManager:
if auto_generate_conversation_name and is_first_message:
# start generate thread
# time.sleep not block other logic
time.sleep(1)
thread = Thread(
target=self._generate_conversation_name_worker,
thread = Timer(
1,
self._generate_conversation_name_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"conversation_id": conversation_id,

View File

@@ -0,0 +1,76 @@
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
from models.model import MessageFile, UploadFile
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
"""
Prepare file dictionary for message end stream response.
:param message_file: MessageFile instance
:param upload_files_map: Dictionary mapping upload_file_id to UploadFile
:return: Dictionary containing file information
"""
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)
url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
if message_file.url.startswith(("http://", "https://")):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
else:
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0]
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
extension = ".bin"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method.value
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
return {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}

View File

@@ -1,6 +1,6 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Union
from typing import Any, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -34,14 +34,14 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
if workflow is None:
raise ValueError("unexpected app type")
features_dict = workflow.features_dict
features_dict: dict[str, Any] = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise ValueError("unexpected app type")
features_dict = app_model_config.to_dict()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
user_input_form = features_dict.get("user_input_form", [])

View File

@@ -65,7 +65,7 @@ class ChromaVector(BaseVector):
self._client.get_or_create_collection(collection_name)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
@@ -73,6 +73,7 @@ class ChromaVector(BaseVector):
collection = self._client.get_or_create_collection(self._collection_name)
# FIXME: chromadb using numpy array, fix the type error later
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
return uuids
def delete_by_metadata_field(self, key: str, value: str):
collection = self._client.get_or_create_collection(self._collection_name)

View File

@@ -605,25 +605,36 @@ class ClickzettaVector(BaseVector):
logger.warning("Failed to create inverted index: %s", e)
# Continue without inverted index - full-text search will fall back to LIKE
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
"""Add documents with embeddings to the collection."""
if not documents:
return
return []
batch_size = self._config.batch_size
total_batches = (len(documents) + batch_size - 1) // batch_size
added_ids = []
for i in range(0, len(documents), batch_size):
batch_docs = documents[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
batch_doc_ids = []
for doc in batch_docs:
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
batch_doc_ids.append(self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))))
added_ids.extend(batch_doc_ids)
# Execute batch insert through write queue
self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)
self._execute_write(
self._insert_batch, batch_docs, batch_embeddings, batch_doc_ids, i, batch_size, total_batches
)
return added_ids
def _insert_batch(
self,
batch_docs: list[Document],
batch_embeddings: list[list[float]],
batch_doc_ids: list[str],
batch_index: int,
batch_size: int,
total_batches: int,
@@ -641,14 +652,9 @@ class ClickzettaVector(BaseVector):
data_rows = []
vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768
for doc, embedding in zip(batch_docs, batch_embeddings):
for doc, embedding, doc_id in zip(batch_docs, batch_embeddings, batch_doc_ids):
# Optimized: minimal checks for common case, fallback for edge cases
metadata = doc.metadata or {}
if not isinstance(metadata, dict):
metadata = {}
doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
# Fast path for JSON serialization
try:

View File

@@ -194,6 +194,13 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
# Create a new database session
with self._session_factory() as session:
existing_model = session.get(WorkflowRun, db_model.id)
if existing_model:
if existing_model.tenant_id != self._tenant_id:
raise ValueError("Unauthorized access to workflow run")
# Preserve the original start time for pause/resume flows.
db_model.created_at = existing_model.created_at
# SQLAlchemy merge intelligently handles both insert and update operations
# based on the presence of the primary key
session.merge(db_model)

View File

@@ -37,6 +37,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.CHECKBOX: ToolParameter.ToolParameterType.BOOLEAN,
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
VariableEntityType.JSON_OBJECT: ToolParameter.ToolParameterType.OBJECT,
}

View File

@@ -4,6 +4,7 @@ import json
import logging
import os
import tempfile
import zipfile
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
@@ -82,8 +83,18 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
value = variable.value
inputs = {"variable_selector": variable_selector}
if isinstance(value, list):
value = list(filter(lambda x: x, value))
process_data = {"documents": value if isinstance(value, list) else [value]}
if not value:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={"text": ArrayStringSegment(value=[])},
)
try:
if isinstance(value, list):
extracted_text_list = [
@@ -111,6 +122,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
else:
raise DocumentExtractorError(f"Unsupported variable type: {type(value)}")
except DocumentExtractorError as e:
logger.warning(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
@@ -385,6 +397,32 @@ def parser_docx_part(block, doc: Document, content_items, i):
content_items.append((i, "table", Table(block, doc)))
def _normalize_docx_zip(file_content: bytes) -> bytes:
"""
Some DOCX files (e.g. exported by Evernote on Windows) are malformed:
ZIP entry names use backslash (\\) as path separator instead of the forward
slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry
"word\\document.xml" is never found when python-docx looks for
"word/document.xml", which triggers a KeyError about a missing relationship.
This function rewrites the ZIP in-memory, normalizing all entry names to
use forward slashes without touching any actual document content.
"""
try:
with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin:
out_buf = io.BytesIO()
with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout:
for item in zin.infolist():
data = zin.read(item.filename)
# Normalize backslash path separators to forward slash
item.filename = item.filename.replace("\\", "/")
zout.writestr(item, data)
return out_buf.getvalue()
except zipfile.BadZipFile:
# Not a valid zip — return as-is and let python-docx report the real error
return file_content
def _extract_text_from_docx(file_content: bytes) -> str:
"""
Extract text from a DOCX file.
@@ -392,7 +430,15 @@ def _extract_text_from_docx(file_content: bytes) -> str:
"""
try:
doc_file = io.BytesIO(file_content)
doc = docx.Document(doc_file)
try:
doc = docx.Document(doc_file)
except Exception as e:
logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e)
# Some DOCX files exported by tools like Evernote on Windows use
# backslash path separators in ZIP entries and/or single-quoted XML
# attributes, both of which break python-docx on Linux. Normalize and retry.
file_content = _normalize_docx_zip(file_content)
doc = docx.Document(io.BytesIO(file_content))
text = []
# Keep track of paragraph and table positions

View File

@@ -23,7 +23,11 @@ from dify_graph.variables import (
)
from dify_graph.variables.segments import ArrayObjectSegment
from .entities import KnowledgeRetrievalNodeData
from .entities import (
Condition,
KnowledgeRetrievalNodeData,
MetadataFilteringCondition,
)
from .exc import (
KnowledgeRetrievalNodeError,
RateLimitExceededError,
@@ -116,7 +120,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
try:
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
outputs = {"result": ArrayObjectSegment(value=[item.model_dump(by_alias=True) for item in results])}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
@@ -171,6 +175,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if node_data.metadata_filtering_mode is not None:
metadata_filtering_mode = node_data.metadata_filtering_mode
resolved_metadata_conditions = (
self._resolve_metadata_filtering_conditions(node_data.metadata_filtering_conditions)
if node_data.metadata_filtering_conditions
else None
)
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
# fetch model config
if node_data.single_retrieval_config is None:
@@ -189,7 +199,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
model_mode=model.mode,
model_name=model.name,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_conditions=resolved_metadata_conditions,
metadata_filtering_mode=metadata_filtering_mode,
query=query,
)
@@ -247,7 +257,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_conditions=resolved_metadata_conditions,
metadata_filtering_mode=metadata_filtering_mode,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
)
@@ -256,6 +266,48 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
usage = self._rag_retrieval.llm_usage
return retrieval_resource_list, usage
def _resolve_metadata_filtering_conditions(
    self, conditions: MetadataFilteringCondition
) -> MetadataFilteringCondition:
    """
    Return a copy of *conditions* with every string value rendered through the
    variable pool's template engine ({{...}} placeholders resolved).

    Non-string values (and sequences containing non-strings) pass through
    unchanged. When the incoming condition list is None, an empty copy with
    the same logical operator is returned.
    """
    if conditions.conditions is None:
        return MetadataFilteringCondition(
            logical_operator=conditions.logical_operator,
            conditions=None,
        )

    pool = self.graph_runtime_state.variable_pool

    def _render(template: str):
        # A single-segment result keeps its native type; otherwise fall back
        # to the concatenated text form.
        group = pool.convert_template(template)
        if len(group.value) == 1:
            return group.value[0].to_object()
        return group.text

    rendered: list[Condition] = []
    for condition in conditions.conditions or []:
        raw = condition.value
        if isinstance(raw, str):
            new_value = _render(raw)
        elif isinstance(raw, Sequence) and all(isinstance(item, str) for item in raw):
            new_value = [_render(item) for item in raw]  # type: ignore
        else:
            new_value = raw
        rendered.append(
            Condition(
                name=condition.name,
                comparison_operator=condition.comparison_operator,
                value=new_value,
            )
        )

    return MetadataFilteringCondition(
        logical_operator=conditions.logical_operator or "and",
        conditions=rendered,
    )
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@@ -65,9 +65,15 @@ class VariablePool(BaseModel):
# Add environment variables to the variable pool
for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
# Add conversation variables to the variable pool
# Add conversation variables to the variable pool. When restoring from a serialized
# snapshot, `variable_dictionary` already carries the latest runtime values.
# In that case, keep existing entries instead of overwriting them with the
# bootstrap list.
for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
selector = (CONVERSATION_VARIABLE_NODE_ID, var.name)
if self._has(selector):
continue
self.add(selector, var)
# Add rag pipeline variables to the variable pool
if self.rag_pipeline_variables:
rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)

View File

@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"

View File

@@ -1,3 +1,5 @@
from typing import Any, cast
from sqlalchemy import select
from events.app_event import app_model_config_was_updated
@@ -54,9 +56,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[s
continue
tool_type = list(tool.keys())[0]
tool_config = list(tool.values())[0]
tool_config = cast(dict[str, Any], list(tool.values())[0])
if tool_type == "dataset":
dataset_ids.add(tool_config.get("id"))
dataset_id = tool_config.get("id")
if isinstance(dataset_id, str):
dataset_ids.add(dataset_id)
# get dataset from dataset_configs
dataset_configs = app_model_config.dataset_configs_dict

View File

@@ -13,6 +13,7 @@ def init_app(app: DifyApp):
convert_to_agent_apps,
create_tenant,
delete_archived_workflow_runs,
export_app_messages,
extract_plugins,
extract_unique_plugins,
file_usage,
@@ -66,6 +67,7 @@ def init_app(app: DifyApp):
restore_workflow_runs,
clean_workflow_runs,
clean_expired_messages,
export_app_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@@ -66,6 +66,7 @@ def run_migrations_offline():
context.configure(
url=url, target_metadata=get_metadata(), literal_binds=True
)
logger.info("Generating offline migration SQL with url: %s", url)
with context.begin_transaction():
context.run_migrations()

View File

@@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
from uuid import uuid4
import sqlalchemy as sa
@@ -15,6 +15,7 @@ from flask import request
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import TypedDict
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
@@ -36,6 +37,259 @@ if TYPE_CHECKING:
from .workflow import Workflow
# --- TypedDict definitions for structured dict return types ---
class EnabledConfig(TypedDict):
enabled: bool
class EmbeddingModelInfo(TypedDict):
embedding_provider_name: str
embedding_model_name: str
class AnnotationReplyDisabledConfig(TypedDict):
enabled: Literal[False]
class AnnotationReplyEnabledConfig(TypedDict):
id: str
enabled: Literal[True]
score_threshold: float
embedding_model: EmbeddingModelInfo
AnnotationReplyConfig = AnnotationReplyEnabledConfig | AnnotationReplyDisabledConfig
class SensitiveWordAvoidanceConfig(TypedDict):
enabled: bool
type: str
config: dict[str, Any]
class AgentToolConfig(TypedDict):
provider_type: str
provider_id: str
tool_name: str
tool_parameters: dict[str, Any]
plugin_unique_identifier: NotRequired[str | None]
credential_id: NotRequired[str | None]
class AgentModeConfig(TypedDict):
enabled: bool
strategy: str | None
tools: list[AgentToolConfig | dict[str, Any]]
prompt: str | None
class ImageUploadConfig(TypedDict):
enabled: bool
number_limits: int
detail: str
transfer_methods: list[str]
class FileUploadConfig(TypedDict):
image: ImageUploadConfig
class DeletedToolInfo(TypedDict):
type: str
tool_name: str
provider_id: str
class ExternalDataToolConfig(TypedDict):
enabled: bool
variable: str
type: str
config: dict[str, Any]
class UserInputFormItemConfig(TypedDict):
variable: str
label: str
description: NotRequired[str]
required: NotRequired[bool]
max_length: NotRequired[int]
options: NotRequired[list[str]]
default: NotRequired[str]
type: NotRequired[str]
config: NotRequired[dict[str, Any]]
# Each item is a single-key dict, e.g. {"text-input": UserInputFormItemConfig}
UserInputFormItem = dict[str, UserInputFormItemConfig]
class DatasetConfigs(TypedDict):
retrieval_model: str
datasets: NotRequired[dict[str, Any]]
top_k: NotRequired[int]
score_threshold: NotRequired[float]
score_threshold_enabled: NotRequired[bool]
reranking_model: NotRequired[dict[str, Any] | None]
weights: NotRequired[dict[str, Any] | None]
reranking_enabled: NotRequired[bool]
reranking_mode: NotRequired[str]
metadata_filtering_mode: NotRequired[str]
metadata_model_config: NotRequired[dict[str, Any] | None]
metadata_filtering_conditions: NotRequired[dict[str, Any] | None]
class ChatPromptMessage(TypedDict):
text: str
role: str
class ChatPromptConfig(TypedDict, total=False):
prompt: list[ChatPromptMessage]
class CompletionPromptText(TypedDict):
text: str
class ConversationHistoriesRole(TypedDict):
user_prefix: str
assistant_prefix: str
class CompletionPromptConfig(TypedDict):
prompt: CompletionPromptText
conversation_histories_role: NotRequired[ConversationHistoriesRole]
class ModelConfig(TypedDict):
provider: str
name: str
mode: str
completion_params: NotRequired[dict[str, Any]]
class AppModelConfigDict(TypedDict):
opening_statement: str | None
suggested_questions: list[str]
suggested_questions_after_answer: EnabledConfig
speech_to_text: EnabledConfig
text_to_speech: EnabledConfig
retriever_resource: EnabledConfig
annotation_reply: AnnotationReplyConfig
more_like_this: EnabledConfig
sensitive_word_avoidance: SensitiveWordAvoidanceConfig
external_data_tools: list[ExternalDataToolConfig]
model: ModelConfig
user_input_form: list[UserInputFormItem]
dataset_query_variable: str | None
pre_prompt: str | None
agent_mode: AgentModeConfig
prompt_type: str
chat_prompt_config: ChatPromptConfig
completion_prompt_config: CompletionPromptConfig
dataset_configs: DatasetConfigs
file_upload: FileUploadConfig
# Added dynamically in Conversation.model_config
model_id: NotRequired[str | None]
provider: NotRequired[str | None]
class ConversationDict(TypedDict):
id: str
app_id: str
app_model_config_id: str | None
model_provider: str | None
override_model_configs: str | None
model_id: str | None
mode: str
name: str
summary: str | None
inputs: dict[str, Any]
introduction: str | None
system_instruction: str | None
system_instruction_tokens: int
status: str
invoke_from: str | None
from_source: str
from_end_user_id: str | None
from_account_id: str | None
read_at: datetime | None
read_account_id: str | None
dialogue_count: int
created_at: datetime
updated_at: datetime
class MessageDict(TypedDict):
id: str
app_id: str
conversation_id: str
model_id: str | None
inputs: dict[str, Any]
query: str
total_price: Decimal | None
message: dict[str, Any]
answer: str
status: str
error: str | None
message_metadata: dict[str, Any]
from_source: str
from_end_user_id: str | None
from_account_id: str | None
created_at: str
updated_at: str
agent_based: bool
workflow_run_id: str | None
class MessageFeedbackDict(TypedDict):
id: str
app_id: str
conversation_id: str
message_id: str
rating: str
content: str | None
from_source: str
from_end_user_id: str | None
from_account_id: str | None
created_at: str
updated_at: str
class MessageFileInfo(TypedDict, total=False):
belongs_to: str | None
upload_file_id: str | None
id: str
tenant_id: str
type: str
transfer_method: str
remote_url: str | None
related_id: str | None
filename: str | None
extension: str | None
mime_type: str | None
size: int
dify_model_identity: str
url: str | None
class ExtraContentDict(TypedDict, total=False):
type: str
workflow_run_id: str
class TraceAppConfigDict(TypedDict):
id: str
app_id: str
tracing_provider: str | None
tracing_config: dict[str, Any]
is_active: bool
created_at: str | None
updated_at: str | None
class DifySetup(TypeBase):
__tablename__ = "dify_setups"
__table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
@@ -176,7 +430,7 @@ class App(Base):
return str(self.mode)
@property
def deleted_tools(self) -> list[dict[str, str]]:
def deleted_tools(self) -> list[DeletedToolInfo]:
from core.tools.tool_manager import ToolManager, ToolProviderType
from services.plugin.plugin_service import PluginService
@@ -257,7 +511,7 @@ class App(Base):
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
}
deleted_tools: list[dict[str, str]] = []
deleted_tools: list[DeletedToolInfo] = []
for tool in tools:
keys = list(tool.keys())
@@ -364,35 +618,38 @@ class AppModelConfig(TypeBase):
return app
@property
def model_dict(self) -> dict[str, Any]:
return json.loads(self.model) if self.model else {}
def model_dict(self) -> ModelConfig:
return cast(ModelConfig, json.loads(self.model) if self.model else {})
@property
def suggested_questions_list(self) -> list[str]:
return json.loads(self.suggested_questions) if self.suggested_questions else []
@property
def suggested_questions_after_answer_dict(self) -> dict[str, Any]:
return (
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
return cast(
EnabledConfig,
json.loads(self.suggested_questions_after_answer)
if self.suggested_questions_after_answer
else {"enabled": False}
else {"enabled": False},
)
@property
def speech_to_text_dict(self) -> dict[str, Any]:
return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}
def speech_to_text_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
@property
def text_to_speech_dict(self) -> dict[str, Any]:
return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}
def text_to_speech_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
@property
def retriever_resource_dict(self) -> dict[str, Any]:
return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
def retriever_resource_dict(self) -> EnabledConfig:
return cast(
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
)
@property
def annotation_reply_dict(self) -> dict[str, Any]:
def annotation_reply_dict(self) -> AnnotationReplyConfig:
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
)
@@ -415,56 +672,62 @@ class AppModelConfig(TypeBase):
return {"enabled": False}
@property
def more_like_this_dict(self) -> dict[str, Any]:
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
def more_like_this_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
@property
def sensitive_word_avoidance_dict(self) -> dict[str, Any]:
return (
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
return cast(
SensitiveWordAvoidanceConfig,
json.loads(self.sensitive_word_avoidance)
if self.sensitive_word_avoidance
else {"enabled": False, "type": "", "configs": []}
else {"enabled": False, "type": "", "config": {}},
)
@property
def external_data_tools_list(self) -> list[dict[str, Any]]:
def external_data_tools_list(self) -> list[ExternalDataToolConfig]:
return json.loads(self.external_data_tools) if self.external_data_tools else []
@property
def user_input_form_list(self) -> list[dict[str, Any]]:
def user_input_form_list(self) -> list[UserInputFormItem]:
return json.loads(self.user_input_form) if self.user_input_form else []
@property
def agent_mode_dict(self) -> dict[str, Any]:
return (
def agent_mode_dict(self) -> AgentModeConfig:
return cast(
AgentModeConfig,
json.loads(self.agent_mode)
if self.agent_mode
else {"enabled": False, "strategy": None, "tools": [], "prompt": None}
else {"enabled": False, "strategy": None, "tools": [], "prompt": None},
)
@property
def chat_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
def chat_prompt_config_dict(self) -> ChatPromptConfig:
return cast(ChatPromptConfig, json.loads(self.chat_prompt_config) if self.chat_prompt_config else {})
@property
def completion_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
def completion_prompt_config_dict(self) -> CompletionPromptConfig:
return cast(
CompletionPromptConfig,
json.loads(self.completion_prompt_config) if self.completion_prompt_config else {},
)
@property
def dataset_configs_dict(self) -> dict[str, Any]:
def dataset_configs_dict(self) -> DatasetConfigs:
if self.dataset_configs:
dataset_configs: dict[str, Any] = json.loads(self.dataset_configs)
dataset_configs = json.loads(self.dataset_configs)
if "retrieval_model" not in dataset_configs:
return {"retrieval_model": "single"}
else:
return dataset_configs
return cast(DatasetConfigs, dataset_configs)
return {
"retrieval_model": "multiple",
}
@property
def file_upload_dict(self) -> dict[str, Any]:
return (
def file_upload_dict(self) -> FileUploadConfig:
return cast(
FileUploadConfig,
json.loads(self.file_upload)
if self.file_upload
else {
@@ -474,10 +737,10 @@ class AppModelConfig(TypeBase):
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
}
},
)
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> AppModelConfigDict:
return {
"opening_statement": self.opening_statement,
"suggested_questions": self.suggested_questions_list,
@@ -501,36 +764,42 @@ class AppModelConfig(TypeBase):
"file_upload": self.file_upload_dict,
}
def from_model_config_dict(self, model_config: Mapping[str, Any]):
def from_model_config_dict(self, model_config: AppModelConfigDict):
self.opening_statement = model_config.get("opening_statement")
self.suggested_questions = (
json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None
json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None
)
self.suggested_questions_after_answer = (
json.dumps(model_config["suggested_questions_after_answer"])
json.dumps(model_config.get("suggested_questions_after_answer"))
if model_config.get("suggested_questions_after_answer")
else None
)
self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None
self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None
self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None
self.speech_to_text = (
json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None
)
self.text_to_speech = (
json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None
)
self.more_like_this = (
json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None
)
self.sensitive_word_avoidance = (
json.dumps(model_config["sensitive_word_avoidance"])
json.dumps(model_config.get("sensitive_word_avoidance"))
if model_config.get("sensitive_word_avoidance")
else None
)
self.external_data_tools = (
json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None
json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None
)
self.model = json.dumps(model_config["model"]) if model_config.get("model") else None
self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None
self.user_input_form = (
json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None
json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None
)
self.dataset_query_variable = model_config.get("dataset_query_variable")
self.pre_prompt = model_config["pre_prompt"]
self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None
self.pre_prompt = model_config.get("pre_prompt")
self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None
self.retriever_resource = (
json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
)
self.prompt_type = model_config.get("prompt_type", "simple")
self.chat_prompt_config = (
@@ -823,24 +1092,26 @@ class Conversation(Base):
self._inputs = inputs
@property
def model_config(self):
model_config = {}
def model_config(self) -> AppModelConfigDict:
model_config = cast(AppModelConfigDict, {})
app_model_config: AppModelConfig | None = None
if self.mode == AppMode.ADVANCED_CHAT:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
model_config = override_model_configs
model_config = cast(AppModelConfigDict, override_model_configs)
else:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
if "model" in override_model_configs:
# where is app_id?
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs)
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(
cast(AppModelConfigDict, override_model_configs)
)
model_config = app_model_config.to_dict()
else:
model_config["configs"] = override_model_configs
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
else:
app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
@@ -1015,7 +1286,7 @@ class Conversation(Base):
def in_debug_mode(self) -> bool:
return self.override_model_configs is not None
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> ConversationDict:
return {
"id": self.id,
"app_id": self.app_id,
@@ -1295,7 +1566,7 @@ class Message(Base):
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
@property
def message_files(self) -> list[dict[str, Any]]:
def message_files(self) -> list[MessageFileInfo]:
from factories import file_factory
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
@@ -1350,10 +1621,13 @@ class Message(Base):
)
files.append(file)
result: list[dict[str, Any]] = [
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
for (file, message_file) in zip(files, message_files)
]
result = cast(
list[MessageFileInfo],
[
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
for (file, message_file) in zip(files, message_files)
],
)
db.session.commit()
return result
@@ -1363,7 +1637,7 @@ class Message(Base):
self._extra_contents = list(contents)
@property
def extra_contents(self) -> list[dict[str, Any]]:
def extra_contents(self) -> list[ExtraContentDict]:
return getattr(self, "_extra_contents", [])
@property
@@ -1379,7 +1653,7 @@ class Message(Base):
return None
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> MessageDict:
return {
"id": self.id,
"app_id": self.app_id,
@@ -1403,7 +1677,7 @@ class Message(Base):
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Message:
def from_dict(cls, data: MessageDict) -> Message:
return cls(
id=data["id"],
app_id=data["app_id"],
@@ -1463,7 +1737,7 @@ class MessageFeedback(TypeBase):
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> MessageFeedbackDict:
return {
"id": str(self.id),
"app_id": str(self.app_id),
@@ -1726,8 +2000,8 @@ class AppMCPServer(TypeBase):
return result
@property
def parameters_dict(self) -> dict[str, Any]:
return cast(dict[str, Any], json.loads(self.parameters))
def parameters_dict(self) -> dict[str, str]:
return cast(dict[str, str], json.loads(self.parameters))
class Site(Base):
@@ -2167,7 +2441,7 @@ class TraceAppConfig(TypeBase):
def tracing_config_str(self) -> str:
return json.dumps(self.tracing_config_dict)
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> TraceAppConfigDict:
return {
"id": self.id,
"app_id": self.app_id,

View File

@@ -35,7 +35,7 @@ dependencies = [
"jsonschema>=4.25.1",
"langfuse~=2.51.3",
"langsmith~=0.1.77",
"markdown~=3.5.1",
"markdown~=3.8.1",
"mlflow-skinny>=3.0.0",
"numpy~=1.26.4",
"openpyxl~=3.1.5",
@@ -247,3 +247,13 @@ module = [
"extensions.logstore.repositories.logstore_api_workflow_run_repository",
]
ignore_errors = true
[tool.pyrefly]
project-includes = ["."]
project-excludes = [
".venv",
"migrations/",
]
python-platform = "linux"
python-version = "3.11.0"
infer-with-first-use = false

View File

@@ -0,0 +1,200 @@
configs/middleware/cache/redis_pubsub_config.py
controllers/console/app/annotation.py
controllers/console/app/app.py
controllers/console/app/app_import.py
controllers/console/app/mcp_server.py
controllers/console/app/site.py
controllers/console/auth/email_register.py
controllers/console/human_input_form.py
controllers/console/init_validate.py
controllers/console/ping.py
controllers/console/setup.py
controllers/console/version.py
controllers/console/workspace/trigger_providers.py
controllers/service_api/app/annotation.py
controllers/web/workflow_events.py
core/agent/fc_agent_runner.py
core/app/apps/advanced_chat/app_generator.py
core/app/apps/advanced_chat/app_runner.py
core/app/apps/advanced_chat/generate_task_pipeline.py
core/app/apps/agent_chat/app_generator.py
core/app/apps/base_app_generate_response_converter.py
core/app/apps/base_app_generator.py
core/app/apps/chat/app_generator.py
core/app/apps/common/workflow_response_converter.py
core/app/apps/completion/app_generator.py
core/app/apps/pipeline/pipeline_generator.py
core/app/apps/pipeline/pipeline_runner.py
core/app/apps/workflow/app_generator.py
core/app/apps/workflow/app_runner.py
core/app/apps/workflow/generate_task_pipeline.py
core/app/apps/workflow_app_runner.py
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
core/datasource/datasource_manager.py
core/external_data_tool/api/api.py
core/llm_generator/llm_generator.py
core/llm_generator/output_parser/structured_output.py
core/mcp/mcp_client.py
core/ops/aliyun_trace/data_exporter/traceclient.py
core/ops/arize_phoenix_trace/arize_phoenix_trace.py
core/ops/mlflow_trace/mlflow_trace.py
core/ops/ops_trace_manager.py
core/ops/tencent_trace/client.py
core/ops/tencent_trace/utils.py
core/plugin/backwards_invocation/base.py
core/plugin/backwards_invocation/model.py
core/prompt/utils/extract_thread_messages.py
core/rag/datasource/keyword/jieba/jieba.py
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
core/rag/datasource/vdb/baidu/baidu_vector.py
core/rag/datasource/vdb/chroma/chroma_vector.py
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
core/rag/datasource/vdb/couchbase/couchbase_vector.py
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
core/rag/datasource/vdb/lindorm/lindorm_vector.py
core/rag/datasource/vdb/matrixone/matrixone_vector.py
core/rag/datasource/vdb/milvus/milvus_vector.py
core/rag/datasource/vdb/myscale/myscale_vector.py
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
core/rag/datasource/vdb/opensearch/opensearch_vector.py
core/rag/datasource/vdb/oracle/oraclevector.py
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
core/rag/datasource/vdb/relyt/relyt_vector.py
core/rag/datasource/vdb/tablestore/tablestore_vector.py
core/rag/datasource/vdb/tencent/tencent_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
core/rag/datasource/vdb/upstash/upstash_vector.py
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
core/rag/datasource/vdb/weaviate/weaviate_vector.py
core/rag/extractor/csv_extractor.py
core/rag/extractor/excel_extractor.py
core/rag/extractor/firecrawl/firecrawl_app.py
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
core/rag/extractor/html_extractor.py
core/rag/extractor/jina_reader_extractor.py
core/rag/extractor/markdown_extractor.py
core/rag/extractor/notion_extractor.py
core/rag/extractor/pdf_extractor.py
core/rag/extractor/text_extractor.py
core/rag/extractor/unstructured/unstructured_doc_extractor.py
core/rag/extractor/unstructured/unstructured_eml_extractor.py
core/rag/extractor/unstructured/unstructured_epub_extractor.py
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
core/rag/extractor/unstructured/unstructured_msg_extractor.py
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
core/rag/extractor/unstructured/unstructured_xml_extractor.py
core/rag/extractor/watercrawl/client.py
core/rag/extractor/watercrawl/extractor.py
core/rag/extractor/watercrawl/provider.py
core/rag/extractor/word_extractor.py
core/rag/index_processor/processor/paragraph_index_processor.py
core/rag/index_processor/processor/parent_child_index_processor.py
core/rag/index_processor/processor/qa_index_processor.py
core/rag/retrieval/router/multi_dataset_function_call_router.py
core/rag/summary_index/summary_index.py
core/repositories/sqlalchemy_workflow_execution_repository.py
core/repositories/sqlalchemy_workflow_node_execution_repository.py
core/tools/__base/tool.py
core/tools/mcp_tool/provider.py
core/tools/plugin_tool/provider.py
core/tools/utils/message_transformer.py
core/tools/utils/web_reader_tool.py
core/tools/workflow_as_tool/provider.py
core/trigger/debug/event_selectors.py
core/trigger/entities/entities.py
core/trigger/provider.py
core/workflow/workflow_entry.py
dify_graph/entities/workflow_execution.py
dify_graph/file/file_manager.py
dify_graph/graph_engine/error_handler.py
dify_graph/graph_engine/layers/execution_limits.py
dify_graph/nodes/agent/agent_node.py
dify_graph/nodes/base/node.py
dify_graph/nodes/code/code_node.py
dify_graph/nodes/datasource/datasource_node.py
dify_graph/nodes/document_extractor/node.py
dify_graph/nodes/human_input/human_input_node.py
dify_graph/nodes/if_else/if_else_node.py
dify_graph/nodes/iteration/iteration_node.py
dify_graph/nodes/knowledge_index/knowledge_index_node.py
dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py
dify_graph/nodes/list_operator/node.py
dify_graph/nodes/llm/node.py
dify_graph/nodes/loop/loop_node.py
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
dify_graph/nodes/question_classifier/question_classifier_node.py
dify_graph/nodes/start/start_node.py
dify_graph/nodes/template_transform/template_transform_node.py
dify_graph/nodes/tool/tool_node.py
dify_graph/nodes/trigger_plugin/trigger_event_node.py
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
dify_graph/nodes/trigger_webhook/node.py
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
dify_graph/nodes/variable_assigner/v1/node.py
dify_graph/nodes/variable_assigner/v2/node.py
dify_graph/variables/types.py
extensions/ext_fastopenapi.py
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
extensions/otel/instrumentation.py
extensions/otel/runtime.py
extensions/storage/aliyun_oss_storage.py
extensions/storage/aws_s3_storage.py
extensions/storage/azure_blob_storage.py
extensions/storage/baidu_obs_storage.py
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
extensions/storage/clickzetta_volume/file_lifecycle.py
extensions/storage/google_cloud_storage.py
extensions/storage/huawei_obs_storage.py
extensions/storage/opendal_storage.py
extensions/storage/oracle_oci_storage.py
extensions/storage/supabase_storage.py
extensions/storage/tencent_cos_storage.py
extensions/storage/volcengine_tos_storage.py
factories/variable_factory.py
libs/external_api.py
libs/gmpy2_pkcs10aep_cipher.py
libs/helper.py
libs/login.py
libs/module_loading.py
libs/oauth.py
libs/oauth_data_source.py
models/trigger.py
models/workflow.py
repositories/sqlalchemy_api_workflow_node_execution_repository.py
repositories/sqlalchemy_api_workflow_run_repository.py
repositories/sqlalchemy_execution_extra_content_repository.py
schedule/queue_monitor_task.py
services/account_service.py
services/audio_service.py
services/auth/firecrawl/firecrawl.py
services/auth/jina.py
services/auth/jina/jina.py
services/auth/watercrawl/watercrawl.py
services/conversation_service.py
services/dataset_service.py
services/document_indexing_proxy/document_indexing_task_proxy.py
services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py
services/external_knowledge_service.py
services/plugin/plugin_migration.py
services/recommend_app/buildin/buildin_retrieval.py
services/recommend_app/database/database_retrieval.py
services/recommend_app/remote/remote_retrieval.py
services/summary_index_service.py
services/tools/tools_transform_service.py
services/trigger/trigger_provider_service.py
services/trigger/trigger_subscription_builder_service.py
services/trigger/webhook_service.py
services/workflow_draft_variable_service.py
services/workflow_event_snapshot_service.py
services/workflow_service.py
tasks/app_generate/workflow_execute_task.py
tasks/regenerate_summary_index_task.py
tasks/trigger_processing_tasks.py
tasks/workflow_cfs_scheduler/cfs_scheduler.py
tasks/workflow_execution_tasks.py

View File

@@ -1,8 +0,0 @@
project-includes = ["."]
project-excludes = [
".venv",
"migrations/",
]
python-platform = "linux"
python-version = "3.11.0"
infer-with-first-use = false

View File

@@ -1,5 +1,6 @@
[pytest]
addopts = --cov=./api --cov-report=json
pythonpath = .
addopts = --cov=./api --cov-report=json --import-mode=importlib
env =
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
@@ -19,7 +20,7 @@ env =
GOOGLE_API_KEY = abcdefghijklmnopqrstuvwxyz
HUGGINGFACE_API_KEY = hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = c
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL = a
MIXEDBREAD_API_KEY = mk-aaaaaaaaaaaaaaaaaaaa
MOCK_SWITCH = true

View File

@@ -21,6 +21,10 @@ celery_redis = Redis(
ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
# Add conservative socket timeouts and health checks to avoid long-lived half-open sockets
socket_timeout=5,
socket_connect_timeout=5,
health_check_interval=30,
)
logger = logging.getLogger(__name__)

View File

@@ -3,6 +3,7 @@ import math
import time
from collections.abc import Iterable, Sequence
from celery import group
from sqlalchemy import ColumnElement, and_, func, or_, select
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import Session
@@ -85,20 +86,25 @@ def trigger_provider_refresh() -> None:
lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
enqueued: int = 0
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
if not is_locked:
continue
trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
enqueued += 1
if not any(acquired):
continue
jobs = [
trigger_subscription_refresh.s(tenant_id=tenant_id, subscription_id=subscription_id)
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired)
if is_locked
]
result = group(jobs).apply_async()
enqueued = len(jobs)
logger.info(
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d",
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d result=%s",
page + 1,
pages,
len(subscriptions),
sum(1 for x in acquired if x),
enqueued,
result,
)
logger.info("Trigger refresh scan done: due=%d", total_due)

View File

@@ -1,6 +1,6 @@
import logging
from celery import group, shared_task
from celery import current_app, group, shared_task
from sqlalchemy import and_, select
from sqlalchemy.orm import Session, sessionmaker
@@ -29,31 +29,27 @@ def poll_workflow_schedules() -> None:
with session_factory() as session:
total_dispatched = 0
# Process in batches until we've handled all due schedules or hit the limit
while True:
due_schedules = _fetch_due_schedules(session)
if not due_schedules:
break
dispatched_count = _process_schedules(session, due_schedules)
total_dispatched += dispatched_count
with current_app.producer_or_acquire() as producer: # type: ignore
dispatched_count = _process_schedules(session, due_schedules, producer)
total_dispatched += dispatched_count
logger.debug("Batch processed: %d dispatched", dispatched_count)
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
if (
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0
and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK
):
logger.warning(
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
)
break
logger.debug("Batch processed: %d dispatched", dispatched_count)
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
if 0 < dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK <= total_dispatched:
logger.warning(
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
)
break
if total_dispatched > 0:
logger.info("Total processed: %d dispatched", total_dispatched)
logger.info("Total processed: %d workflow schedule(s) dispatched", total_dispatched)
def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
@@ -90,7 +86,7 @@ def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
return list(due_schedules)
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> int:
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan], producer=None) -> int:
"""Process schedules: check quota, update next run time and dispatch to Celery in parallel."""
if not schedules:
return 0
@@ -107,7 +103,7 @@ def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan])
if tasks_to_dispatch:
job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch)
job.apply_async()
job.apply_async(producer=producer)
logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch))

View File

@@ -4,6 +4,7 @@ import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import cast
from urllib.parse import urlparse
from uuid import uuid4
@@ -32,7 +33,7 @@ from extensions.ext_redis import redis_client
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from models import Account, App, AppMode
from models.model import AppModelConfig, IconType
from models.model import AppModelConfig, AppModelConfigDict, IconType
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
@@ -523,7 +524,7 @@ class AppDslService:
if not app.app_model_config:
app_model_config = AppModelConfig(
app_id=app.id, created_by=account.id, updated_by=account.id
).from_model_config_dict(model_config)
).from_model_config_dict(cast(AppModelConfigDict, model_config))
app_model_config.id = str(uuid4())
app.app_model_config_id = app_model_config.id

View File

@@ -1,12 +1,12 @@
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from models.model import AppMode
from models.model import AppMode, AppModelConfigDict
class AppModelConfigService:
@classmethod
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
if app_mode == AppMode.CHAT:
return ChatAppConfigManager.config_validate(tenant_id, config)
elif app_mode == AppMode.AGENT_CHAT:

View File

@@ -1,6 +1,6 @@
import json
import logging
from typing import TypedDict, cast
from typing import Any, TypedDict, cast
import sqlalchemy as sa
from flask_sqlalchemy.pagination import Pagination
@@ -187,7 +187,7 @@ class AppService:
for tool in agent_mode.get("tools") or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
agent_tool_entity = AgentToolEntity(**cast(dict[str, Any], tool))
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
@@ -388,7 +388,7 @@ class AppService:
agent_config = app_model_config.agent_mode_dict
# get all tools
tools = agent_config.get("tools", [])
tools = cast(list[dict[str, Any]], agent_config.get("tools", []))
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"

View File

@@ -2,6 +2,7 @@ import io
import logging
import uuid
from collections.abc import Generator
from typing import cast
from flask import Response, stream_with_context
from werkzeug.datastructures import FileStorage
@@ -106,7 +107,7 @@ class AudioService:
if not text_to_speech_dict.get("enabled"):
raise ValueError("TTS is not enabled")
voice = text_to_speech_dict.get("voice")
voice = cast(str | None, text_to_speech_dict.get("voice"))
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(

View File

@@ -63,7 +63,12 @@ class RagPipelineTransformService:
):
node = self._deal_file_extensions(node)
if node.get("data", {}).get("type") == "knowledge-index":
node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if dataset.tenant_id != current_user.current_tenant_id:
raise ValueError("Unauthorized")
node = self._deal_knowledge_index(
knowledge_configuration, dataset, indexing_technique, retrieval_model, node
)
new_nodes.append(node)
if new_nodes:
graph["nodes"] = new_nodes
@@ -155,14 +160,13 @@ class RagPipelineTransformService:
def _deal_knowledge_index(
self,
knowledge_configuration: KnowledgeConfiguration,
dataset: Dataset,
doc_form: str,
indexing_technique: str | None,
retrieval_model: RetrievalSetting | None,
node: dict,
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
if indexing_technique == "high_quality":
knowledge_configuration.embedding_model = dataset.embedding_model

View File

@@ -0,0 +1,304 @@
"""
Export app messages to JSONL.GZ format.
Outputs: conversation_id, message_id, query, answer, inputs (raw JSON),
retriever_resources (from message_metadata), feedback (user feedbacks array).
Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1.
Does NOT touch Message.inputs / Message.user_feedback properties.
"""
import datetime
import gzip
import json
import logging
import tempfile
from collections import defaultdict
from collections.abc import Generator, Iterable
from pathlib import Path, PurePosixPath
from typing import Any, BinaryIO, cast
import orjson
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select, tuple_
from sqlalchemy.orm import Session
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import Message, MessageFeedback
logger = logging.getLogger(__name__)
MAX_FILENAME_BASE_LENGTH = 1024
FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz")
class AppMessageExportFeedback(BaseModel):
    """One user feedback row attached to an exported message.

    Populated from ``MessageFeedback.to_dict()`` (see ``_fetch_feedbacks``);
    timestamps therefore arrive as already-serialized strings, not datetimes.
    """

    id: str
    app_id: str
    conversation_id: str
    message_id: str
    # Rating value as stored on MessageFeedback (e.g. like/dislike) —
    # presumably a small closed set; confirm against the model definition.
    rating: str
    content: str | None = None
    # Only rows with from_source == "user" are exported (see _fetch_feedbacks).
    from_source: str
    from_end_user_id: str | None = None
    from_account_id: str | None = None
    created_at: str
    updated_at: str
    # Forbid unexpected keys so schema drift in MessageFeedback.to_dict()
    # fails loudly at validation time instead of being silently dropped.
    model_config = ConfigDict(extra="forbid")
class AppMessageExportRecord(BaseModel):
    """One exported message — serialized as a single line of the JSONL output."""

    conversation_id: str
    message_id: str
    query: str
    answer: str
    # Raw Message._inputs column content, exported verbatim (the Message.inputs
    # property is deliberately bypassed — see the module docstring).
    inputs: dict[str, Any]
    # Extracted from message_metadata["retriever_resources"] when present and a list.
    retriever_resources: list[Any] = Field(default_factory=list)
    # User feedbacks for this message, batch-loaded per page.
    feedback: list[AppMessageExportFeedback] = Field(default_factory=list)
    # Strict schema: any extra key indicates a builder bug.
    model_config = ConfigDict(extra="forbid")
class AppMessageExportStats(BaseModel):
    """Aggregate counters accumulated while streaming an export."""

    # Derived after the fact as ceil(total_messages / batch_size) in
    # _finalize_stats — it is NOT a count of actual DB fetches.
    batches: int = 0
    total_messages: int = 0
    # Number of messages that had at least one user feedback.
    messages_with_feedback: int = 0
    # Total feedback rows across all messages.
    total_feedbacks: int = 0
    model_config = ConfigDict(extra="forbid")
class AppMessageExportService:
    """Stream an app's messages (with user feedback) to a JSONL.GZ export.

    Pages through ``Message`` rows with a (created_at, id) keyset cursor and
    batch-loads feedbacks per page to avoid N+1 queries.  The raw ``_inputs``
    column is read directly; the ``Message.inputs`` / ``Message.user_feedback``
    properties are never touched (see module docstring).
    """

    @staticmethod
    def validate_export_filename(filename: str) -> str:
        """Validate a user-supplied base filename and return it stripped.

        Raises:
            ValueError: for empty names, forbidden archive suffixes, absolute
                paths, backslashes, empty path segments (``//``), over-length
                names, control characters/NUL, or '.'/'..' path components.
        """
        normalized = filename.strip()
        if not normalized:
            raise ValueError("--filename must not be empty.")
        normalized_lower = normalized.lower()
        # Output suffixes are appended automatically (see output_gz_name),
        # so the caller must pass the base name only.
        if normalized_lower.endswith(FORBIDDEN_FILENAME_SUFFIXES):
            raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.")
        if normalized.startswith("/"):
            raise ValueError("--filename must be a relative path; absolute paths are not allowed.")
        if "\\" in normalized:
            raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.")
        if "//" in normalized:
            raise ValueError("--filename must not contain empty path segments ('//').")
        if len(normalized) > MAX_FILENAME_BASE_LENGTH:
            raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.")
        # NOTE: the explicit NUL check is already implied by ord(ch) < 32;
        # it is kept for clarity of intent.
        for ch in normalized:
            if ch == "\x00" or ord(ch) < 32 or ord(ch) == 127:
                raise ValueError("--filename must not contain control characters or NUL.")
        parts = PurePosixPath(normalized).parts
        if not parts:
            raise ValueError("--filename must include a file name.")
        if any(part in (".", "..") for part in parts):
            raise ValueError("--filename must not contain '.' or '..' path segments.")
        return normalized

    @property
    def output_gz_name(self) -> str:
        # Final name/key of the gzip-compressed export artifact.
        return f"{self._filename_base}.jsonl.gz"

    @property
    def output_jsonl_name(self) -> str:
        # Uncompressed counterpart name; this service itself only ever writes
        # the .jsonl.gz form.
        return f"{self._filename_base}.jsonl"

    def __init__(
        self,
        app_id: str,
        end_before: datetime.datetime,
        filename: str,
        *,
        start_from: datetime.datetime | None = None,
        batch_size: int = 1000,
        use_cloud_storage: bool = False,
        dry_run: bool = False,
    ) -> None:
        """Configure an export of messages created in [start_from, end_before).

        Args:
            app_id: App whose messages are exported.
            end_before: Exclusive upper bound on ``Message.created_at``.
            filename: Base output name; validated via validate_export_filename.
            start_from: Optional inclusive lower bound on ``created_at``.
            batch_size: Rows fetched per keyset page (and divisor for the
                derived ``batches`` stat).
            use_cloud_storage: Upload via ``storage`` instead of writing to CWD.
            dry_run: Iterate and count records without writing anything.

        Raises:
            ValueError: if ``start_from`` is not strictly before ``end_before``,
                or the filename fails validation.
        """
        if start_from and start_from >= end_before:
            raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})")
        self._app_id = app_id
        self._end_before = end_before
        self._start_from = start_from
        self._filename_base = self.validate_export_filename(filename)
        self._batch_size = batch_size
        self._use_cloud_storage = use_cloud_storage
        self._dry_run = dry_run

    def run(self) -> AppMessageExportStats:
        """Execute the export and return aggregate stats.

        In dry-run mode records are iterated (and counted) but nothing is
        written; otherwise output goes to cloud storage or a local file
        depending on the constructor flag.
        """
        stats = AppMessageExportStats()
        logger.info(
            "export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s",
            self._app_id,
            self._start_from,
            self._end_before,
            self._dry_run,
            self._use_cloud_storage,
            self.output_gz_name,
        )
        if self._dry_run:
            # Drain the generator purely for its stats side effects.
            for _ in self._iter_records_with_stats(stats):
                pass
            self._finalize_stats(stats)
            return stats
        if self._use_cloud_storage:
            self._export_to_cloud(stats)
        else:
            self._export_to_local(stats)
        self._finalize_stats(stats)
        return stats

    def iter_records(self) -> Generator[AppMessageExportRecord, None, None]:
        """Yield export records one by one, flattening the paged batches."""
        for batch in self._iter_record_batches():
            yield from batch

    @staticmethod
    def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None:
        """Gzip-compress *records* into *fileobj*, one JSON object per line."""
        with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz:
            # orjson emits bytes; mode="json" makes nested models JSON-safe.
            for record in records:
                gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")

    def _export_to_local(self, stats: AppMessageExportStats) -> None:
        """Write the export to <cwd>/<base>.jsonl.gz, creating parent dirs."""
        output_path = Path.cwd() / self.output_gz_name
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("wb") as output_file:
            self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file)

    def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
        """Build the gzip in a spooled temp file, then upload it via storage.

        The spooled file stays in memory up to 64 MiB before spilling to disk;
        note the full payload is still read back into memory for the upload.
        NOTE(review): storage.save is assumed to overwrite an existing key —
        confirm against the storage backend's semantics.
        """
        with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp:
            self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp))
            tmp.seek(0)
            data = tmp.read()
            storage.save(self.output_gz_name, data)
            logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self.output_gz_name)

    def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]:
        """Pass records through while updating *stats* in place."""
        for record in self.iter_records():
            self._update_stats(stats, record)
            yield record

    @staticmethod
    def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None:
        # Per-record counters; batches is derived later in _finalize_stats.
        stats.total_messages += 1
        if record.feedback:
            stats.messages_with_feedback += 1
            stats.total_feedbacks += len(record.feedback)

    def _finalize_stats(self, stats: AppMessageExportStats) -> None:
        """Derive the batch count from totals (ceiling division)."""
        if stats.total_messages == 0:
            stats.batches = 0
            return
        stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size

    def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]:
        """Yield pages of records using (created_at, id) keyset pagination."""
        cursor: tuple[datetime.datetime, str] | None = None
        while True:
            rows, cursor = self._fetch_batch(cursor)
            if not rows:
                break
            message_ids = [str(row.id) for row in rows]
            # Batch-load all feedbacks for this page to avoid N+1 queries.
            feedbacks_map = self._fetch_feedbacks(message_ids)
            yield [self._build_record(row, feedbacks_map) for row in rows]

    def _fetch_batch(
        self, cursor: tuple[datetime.datetime, str] | None
    ) -> tuple[list[Any], tuple[datetime.datetime, str] | None]:
        """Fetch one page of message rows after *cursor*; return (rows, new cursor).

        Returns the unchanged cursor with an empty list when the scan is done.
        """
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = (
                select(
                    Message.id,
                    Message.conversation_id,
                    Message.query,
                    Message.answer,
                    # Raw column access on purpose — avoids the Message.inputs
                    # property (see module docstring).
                    Message._inputs,  # pyright: ignore[reportPrivateUsage]
                    Message.message_metadata,
                    Message.created_at,
                )
                .where(
                    Message.app_id == self._app_id,
                    Message.created_at < self._end_before,
                )
                # Ordering must match the keyset tuple below for pagination
                # to be stable.
                .order_by(Message.created_at, Message.id)
                .limit(self._batch_size)
            )
            if self._start_from:
                stmt = stmt.where(Message.created_at >= self._start_from)
            if cursor:
                # Row-value comparison resumes strictly after the last row of
                # the previous page; explicit literal types keep the tuple
                # comparison well-typed on the SQL side.
                stmt = stmt.where(
                    tuple_(Message.created_at, Message.id)
                    > tuple_(
                        sa.literal(cursor[0], type_=sa.DateTime()),
                        sa.literal(cursor[1], type_=Message.id.type),
                    )
                )
            rows = list(session.execute(stmt).all())
        if not rows:
            return [], cursor
        last = rows[-1]
        return rows, (last.created_at, last.id)

    def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]:
        """Load user feedbacks for *message_ids*, grouped by message id."""
        if not message_ids:
            return {}
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = (
                select(MessageFeedback)
                .where(
                    MessageFeedback.message_id.in_(message_ids),
                    # Export user feedback only; other sources are excluded.
                    MessageFeedback.from_source == "user",
                )
                .order_by(MessageFeedback.message_id, MessageFeedback.created_at)
            )
            feedbacks = list(session.scalars(stmt).all())
        result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list)
        for feedback in feedbacks:
            result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict()))
        return result

    @staticmethod
    def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord:
        """Assemble one export record from a message row and its feedbacks.

        message_metadata is parsed leniently: malformed JSON or a non-list
        "retriever_resources" value degrades to an empty list rather than
        failing the export.
        """
        retriever_resources: list[Any] = []
        if row.message_metadata:
            try:
                metadata = json.loads(row.message_metadata)
                value = metadata.get("retriever_resources", [])
                if isinstance(value, list):
                    retriever_resources = value
            except (json.JSONDecodeError, TypeError):
                pass
        message_id = str(row.id)
        return AppMessageExportRecord(
            conversation_id=str(row.conversation_id),
            message_id=message_id,
            query=row.query,
            answer=row.answer,
            # Guard against non-dict raw inputs; export {} rather than crash.
            inputs=row._inputs if isinstance(row._inputs, dict) else {},
            retriever_resources=retriever_resources,
            feedback=feedbacks_map.get(message_id, []),
        )

View File

@@ -12,6 +12,7 @@ from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import (
App,
AppAnnotationHitHistory,
@@ -142,7 +143,7 @@ class MessagesCleanService:
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
end_before = datetime.datetime.now() - datetime.timedelta(days=days)
end_before = naive_utc_now() - datetime.timedelta(days=days)
logger.info(
"clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",

View File

@@ -1,9 +1,10 @@
import logging
import time
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from typing import Any, Protocol
import click
from celery import shared_task
from celery import current_app, shared_task
from configs import dify_config
from core.db.session_factory import session_factory
@@ -19,6 +20,12 @@ from tasks.generate_summary_index_task import generate_summary_index_task
logger = logging.getLogger(__name__)
class CeleryTaskLike(Protocol):
def delay(self, *args: Any, **kwargs: Any) -> Any: ...
def apply_async(self, *args: Any, **kwargs: Any) -> Any: ...
@shared_task(queue="dataset")
def document_indexing_task(dataset_id: str, document_ids: list):
"""
@@ -179,8 +186,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
def _document_indexing_with_tenant_queue(
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
):
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: CeleryTaskLike
) -> None:
try:
_document_indexing(dataset_id, document_ids)
except Exception:
@@ -201,16 +208,20 @@ def _document_indexing_with_tenant_queue(
logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
if next_tasks:
for next_task in next_tasks:
document_task = DocumentTask(**next_task)
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=document_task.tenant_id,
dataset_id=document_task.dataset_id,
document_ids=document_task.document_ids,
)
with current_app.producer_or_acquire() as producer: # type: ignore
for next_task in next_tasks:
document_task = DocumentTask(**next_task)
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
task_func.apply_async(
kwargs={
"tenant_id": document_task.tenant_id,
"dataset_id": document_task.dataset_id,
"document_ids": document_task.document_ids,
},
producer=producer,
)
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()

View File

@@ -14,7 +14,7 @@ from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
@shared_task(queue="dataset_summary")
def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None):
"""
Async generate summary index for document segments.

View File

@@ -6,7 +6,6 @@ import typing
import click
from celery import shared_task
from core.helper.marketplace import record_install_plugin_event
from core.plugin.entities.marketplace import MarketplacePluginSnapshot
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
@@ -166,7 +165,6 @@ def process_tenant_plugin_autoupgrade_check_task(
# execute upgrade
new_unique_identifier = manifest.latest_package_identifier
record_install_plugin_event(new_unique_identifier)
click.echo(
click.style(
f"Upgrade plugin: {original_unique_identifier} -> {new_unique_identifier}",

View File

@@ -3,12 +3,13 @@ import json
import logging
import time
import uuid
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
from typing import Any
import click
from celery import shared_task # type: ignore
from celery import group, shared_task
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
@@ -27,6 +28,11 @@ from services.file_service import FileService
logger = logging.getLogger(__name__)
def chunked(iterable: Sequence, size: int):
it = iter(iterable)
return iter(lambda: list(islice(it, size)), [])
@shared_task(queue="pipeline")
def rag_pipeline_run_task(
rag_pipeline_invoke_entities_file_id: str,
@@ -83,16 +89,24 @@ def rag_pipeline_run_task(
logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
if next_file_ids:
for next_file_id in next_file_ids:
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
for batch in chunked(next_file_ids, 100):
jobs = []
for next_file_id in batch:
tenant_isolated_task_queue.set_task_waiting_time()
file_id = (
next_file_id.decode("utf-8") if isinstance(next_file_id, (bytes, bytearray)) else next_file_id
)
jobs.append(
rag_pipeline_run_task.s(
rag_pipeline_invoke_entities_file_id=file_id,
tenant_id=tenant_id,
)
)
if jobs:
group(jobs).apply_async()
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()

View File

@@ -16,7 +16,7 @@ from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
@shared_task(queue="dataset_summary")
def regenerate_summary_index_task(
dataset_id: str,
regenerate_reason: str = "summary_model_changed",

View File

@@ -5,14 +5,10 @@ This test module validates the 400-character limit enforcement
for App descriptions across all creation and editing endpoints.
"""
import os
import sys
import pytest
# Add the API root to Python path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
class TestAppDescriptionValidationUnit:
"""Unit tests for description validation function"""

View File

@@ -10,8 +10,11 @@ more reliable and realistic test scenarios.
import logging
import os
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Protocol, TypeVar
import psycopg2
import pytest
from flask import Flask
from flask.testing import FlaskClient
@@ -31,6 +34,25 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(level
logger = logging.getLogger(__name__)
class _CloserProtocol(Protocol):
"""_Closer is any type which implement the close() method."""
def close(self):
"""close the current object, release any external resouece (file, transaction, connection etc.)
associated with it.
"""
pass
_Closer = TypeVar("_Closer", bound=_CloserProtocol)
@contextmanager
def _auto_close(closer: _Closer) -> Generator[_Closer, None, None]:
yield closer
closer.close()
class DifyTestContainers:
"""
Manages all test containers required for Dify integration tests.
@@ -97,45 +119,28 @@ class DifyTestContainers:
wait_for_logs(self.postgres, "is ready to accept connections", timeout=30)
logger.info("PostgreSQL container is ready and accepting connections")
# Install uuid-ossp extension for UUID generation
logger.info("Installing uuid-ossp extension...")
try:
import psycopg2
conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
cursor = conn.cursor()
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
cursor.close()
conn.close()
conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
with _auto_close(conn):
with conn.cursor() as cursor:
# Install uuid-ossp extension for UUID generation
logger.info("Installing uuid-ossp extension...")
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
logger.info("uuid-ossp extension installed successfully")
except Exception as e:
logger.warning("Failed to install uuid-ossp extension: %s", e)
# Create plugin database for dify-plugin-daemon
logger.info("Creating plugin database...")
try:
conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
cursor = conn.cursor()
cursor.execute("CREATE DATABASE dify_plugin;")
cursor.close()
conn.close()
# NOTE: We cannot use `with conn.cursor() as cursor:` as it will wrap the statement
# inside a transaction. However, the `CREATE DATABASE` statement cannot run inside a transaction block.
with _auto_close(conn.cursor()) as cursor:
# Create plugin database for dify-plugin-daemon
logger.info("Creating plugin database...")
cursor.execute("CREATE DATABASE dify_plugin;")
logger.info("Plugin database created successfully")
except Exception as e:
logger.warning("Failed to create plugin database: %s", e)
# Set up storage environment variables
os.environ.setdefault("STORAGE_TYPE", "opendal")
@@ -258,23 +263,16 @@ class DifyTestContainers:
containers = [self.redis, self.postgres, self.dify_sandbox, self.dify_plugin_daemon]
for container in containers:
if container:
try:
container_name = container.image
logger.info("Stopping container: %s", container_name)
container.stop()
logger.info("Successfully stopped container: %s", container_name)
except Exception as e:
# Log error but don't fail the test cleanup
logger.warning("Failed to stop container %s: %s", container, e)
container_name = container.image
logger.info("Stopping container: %s", container_name)
container.stop()
logger.info("Successfully stopped container: %s", container_name)
# Stop and remove the network
if self.network:
try:
logger.info("Removing Docker network...")
self.network.remove()
logger.info("Successfully removed Docker network")
except Exception as e:
logger.warning("Failed to remove Docker network: %s", e)
logger.info("Removing Docker network...")
self.network.remove()
logger.info("Successfully removed Docker network")
self._containers_started = False
logger.info("All test containers stopped and cleaned up successfully")

View File

@@ -0,0 +1,497 @@
"""
Container-backed integration tests for dataset permission services on the real SQL path.
This module exercises persisted DatasetPermission rows and dataset permission
checks with testcontainers-backed infrastructure instead of database-chain mocks.
"""
from uuid import uuid4
import pytest
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
Dataset,
DatasetPermission,
DatasetPermissionEnum,
)
from services.dataset_service import DatasetPermissionService, DatasetService
from services.errors.account import NoPermissionError
class DatasetPermissionTestDataFactory:
    """Factory helpers that persist accounts, datasets and permission rows for integration tests."""

    @staticmethod
    def create_account_with_tenant(
        role: TenantAccountRole = TenantAccountRole.NORMAL,
        tenant: Tenant | None = None,
    ) -> tuple[Account, Tenant]:
        """Persist a fresh account joined to *tenant* (created on demand) under *role*."""
        new_account = Account(
            email=f"{uuid4()}@example.com",
            name=f"user-{uuid4()}",
            interface_language="en-US",
            status="active",
        )
        if tenant is not None:
            db.session.add(new_account)
        else:
            tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
            db.session.add_all([new_account, tenant])
        db.session.flush()
        membership = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=new_account.id,
            role=role,
            current=True,
        )
        db.session.add(membership)
        db.session.commit()
        new_account.current_tenant = tenant
        return new_account, tenant

    @staticmethod
    def create_dataset(
        tenant_id: str,
        created_by: str,
        permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
        name: str = "Test Dataset",
    ) -> Dataset:
        """Persist a dataset owned by *created_by* inside *tenant_id* with the given permission."""
        new_dataset = Dataset(
            tenant_id=tenant_id,
            name=name,
            description="desc",
            data_source_type="upload_file",
            indexing_technique="high_quality",
            created_by=created_by,
            permission=permission,
            provider="vendor",
            retrieval_model={"top_k": 2},
        )
        db.session.add(new_dataset)
        db.session.commit()
        return new_dataset

    @staticmethod
    def create_dataset_permission(
        dataset_id: str,
        account_id: str,
        tenant_id: str,
        has_permission: bool = True,
    ) -> DatasetPermission:
        """Persist a DatasetPermission row granting *account_id* access to *dataset_id*."""
        permission_row = DatasetPermission(
            dataset_id=dataset_id,
            account_id=account_id,
            tenant_id=tenant_id,
            has_permission=has_permission,
        )
        db.session.add(permission_row)
        db.session.commit()
        return permission_row

    @staticmethod
    def build_user_list_payload(user_ids: list[str]) -> list[dict[str, str]]:
        """Shape *user_ids* into the payload used by partial-member list updates."""
        return [{"user_id": member_id} for member_id in user_ids]
class TestDatasetPermissionServiceGetPartialMemberList:
    """Verify partial-member list reads against persisted DatasetPermission rows."""

    def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers):
        """All granted members are returned when several permission rows exist."""
        # Arrange: an owner-led tenant with three normal members granted access.
        factory = DatasetPermissionTestDataFactory
        owner, tenant = factory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        members = [
            factory.create_account_with_tenant(role=TenantAccountRole.NORMAL, tenant=tenant)[0]
            for _ in range(3)
        ]
        dataset = factory.create_dataset(tenant.id, owner.id)
        expected_account_ids = [member.id for member in members]
        for member_id in expected_account_ids:
            factory.create_dataset_permission(dataset.id, member_id, tenant.id)

        # Act
        result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)

        # Assert
        assert set(result) == set(expected_account_ids)
        assert len(result) == 3

    def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers):
        """A single granted member comes back as a one-element list."""
        factory = DatasetPermissionTestDataFactory
        owner, tenant = factory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        sole_member, _ = factory.create_account_with_tenant(role=TenantAccountRole.NORMAL, tenant=tenant)
        dataset = factory.create_dataset(tenant.id, owner.id)
        factory.create_dataset_permission(dataset.id, sole_member.id, tenant.id)

        result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)

        assert set(result) == {sole_member.id}
        assert len(result) == 1

    def test_get_dataset_partial_member_list_empty(self, db_session_with_containers):
        """No permission rows yields an empty list."""
        factory = DatasetPermissionTestDataFactory
        owner, tenant = factory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        dataset = factory.create_dataset(tenant.id, owner.id)

        result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)

        assert result == []
        assert len(result) == 0
class TestDatasetPermissionServiceUpdatePartialMemberList:
    """Verify partial-member list updates against persisted DatasetPermission rows."""

    def test_update_partial_member_list_add_new_members(self, db_session_with_containers):
        """
        Test adding new partial members to a dataset.
        """
        # Arrange: owner tenant plus two normal members to be granted access.
        owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
        user_list = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id])
        # Act
        DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list)
        # Assert: both members are now readable through the service.
        result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
        assert set(result) == {member_1.id, member_2.id}

    def test_update_partial_member_list_replace_existing(self, db_session_with_containers):
        """
        Test replacing existing partial members with new ones.
        """
        # Arrange: seed two members, then prepare two replacements.
        owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        old_member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        old_member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        new_member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        new_member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
        old_users = DatasetPermissionTestDataFactory.build_user_list_payload([old_member_1.id, old_member_2.id])
        DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, old_users)
        new_users = DatasetPermissionTestDataFactory.build_user_list_payload([new_member_1.id, new_member_2.id])
        # Act: a second update replaces (not merges) the member list.
        DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, new_users)
        # Assert: only the replacements remain.
        result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
        assert set(result) == {new_member_1.id, new_member_2.id}

    def test_update_partial_member_list_empty_list(self, db_session_with_containers):
        """
        Test updating with empty member list (clearing all members).
        """
        # Arrange: seed two members so there is something to clear.
        owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
        users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id])
        DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users)
        # Act: an empty payload removes every member.
        DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, [])
        # Assert
        result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
        assert result == []

    def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers):
        """
        Test error handling and rollback on database error.
        """
        # Arrange: one committed member, and a replacement whose update will fail.
        owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        existing_member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        replacement_member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
        DatasetPermissionService.update_partial_member_list(
            tenant.id,
            dataset.id,
            DatasetPermissionTestDataFactory.build_user_list_payload([existing_member.id]),
        )
        user_list = DatasetPermissionTestDataFactory.build_user_list_payload([replacement_member.id])
        rollback_called = {"count": 0}
        # Capture the real rollback BEFORE patching so the wrapper can still
        # perform the actual session rollback.
        original_rollback = db.session.rollback
        # Act / Assert
        with pytest.MonkeyPatch.context() as mp:

            def _raise_commit():
                # Simulate a commit failure inside the service call.
                raise Exception("Database connection error")

            def _rollback_and_mark():
                # Count invocations while still rolling the session back.
                rollback_called["count"] += 1
                original_rollback()

            mp.setattr("services.dataset_service.db.session.commit", _raise_commit)
            mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark)
            with pytest.raises(Exception, match="Database connection error"):
                DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list)
        # Assert: rollback ran exactly once and the pre-existing member survived.
        result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
        assert rollback_called["count"] == 1
        assert result == [existing_member.id]
        assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 1
class TestDatasetPermissionServiceClearPartialMemberList:
    """Verify partial-member clearing against persisted DatasetPermission rows."""

    def test_clear_partial_member_list_success(self, db_session_with_containers):
        """
        Test successful clearing of partial member list.
        """
        # Arrange: seed two granted members.
        owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
        users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id])
        DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users)
        # Act
        DatasetPermissionService.clear_partial_member_list(dataset.id)
        # Assert: all permission rows removed.
        result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
        assert result == []

    def test_clear_partial_member_list_empty_list(self, db_session_with_containers):
        """
        Test clearing partial member list when no members exist.
        """
        # Arrange: dataset with no permission rows at all.
        owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
        # Act: clearing an already-empty list must not raise.
        DatasetPermissionService.clear_partial_member_list(dataset.id)
        # Assert
        result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
        assert result == []

    def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers):
        """
        Test error handling and rollback on database error.
        """
        # Arrange: two committed members that should survive the failed clear.
        owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL,
            tenant=tenant,
        )
        dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
        users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id])
        DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users)
        rollback_called = {"count": 0}
        # Capture the real rollback BEFORE patching so the wrapper still rolls back.
        original_rollback = db.session.rollback
        # Act / Assert
        with pytest.MonkeyPatch.context() as mp:

            def _raise_commit():
                # Simulate a commit failure inside the service call.
                raise Exception("Database connection error")

            def _rollback_and_mark():
                # Count invocations while still rolling the session back.
                rollback_called["count"] += 1
                original_rollback()

            mp.setattr("services.dataset_service.db.session.commit", _raise_commit)
            mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark)
            with pytest.raises(Exception, match="Database connection error"):
                DatasetPermissionService.clear_partial_member_list(dataset.id)
        # Assert: rollback ran once and both members are still present.
        result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
        assert rollback_called["count"] == 1
        assert set(result) == {member_1.id, member_2.id}
        assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 2
class TestDatasetServiceCheckDatasetPermission:
    """Verify dataset access checks against persisted partial-member permissions."""

    def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers):
        """A user holding an explicit permission row can access a partial-team dataset."""
        factory = DatasetPermissionTestDataFactory
        owner, tenant = factory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        granted_user, _ = factory.create_account_with_tenant(role=TenantAccountRole.NORMAL, tenant=tenant)
        dataset = factory.create_dataset(
            tenant.id,
            owner.id,
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
        )
        factory.create_dataset_permission(dataset.id, granted_user.id, tenant.id)

        # Must not raise for a granted member.
        DatasetService.check_dataset_permission(dataset, granted_user)

        # Sanity-check the permission row is actually persisted.
        permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
        assert granted_user.id in permissions

    def test_check_dataset_permission_partial_members_without_permission_error(self, db_session_with_containers):
        """A user with no permission row is rejected on a partial-team dataset."""
        factory = DatasetPermissionTestDataFactory
        owner, tenant = factory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        outsider, _ = factory.create_account_with_tenant(role=TenantAccountRole.NORMAL, tenant=tenant)
        dataset = factory.create_dataset(
            tenant.id,
            owner.id,
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
        )

        with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
            DatasetService.check_dataset_permission(dataset, outsider)
class TestDatasetServiceCheckDatasetOperatorPermission:
    """Verify operator permission checks against persisted partial-member permissions."""

    def test_check_dataset_operator_permission_partial_members_with_permission_success(
        self, db_session_with_containers
    ):
        """An operator holding an explicit permission row can access a partial-team dataset."""
        factory = DatasetPermissionTestDataFactory
        owner, tenant = factory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        granted_user, _ = factory.create_account_with_tenant(role=TenantAccountRole.NORMAL, tenant=tenant)
        dataset = factory.create_dataset(
            tenant.id,
            owner.id,
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
        )
        factory.create_dataset_permission(dataset.id, granted_user.id, tenant.id)

        # Must not raise for a granted member.
        DatasetService.check_dataset_operator_permission(user=granted_user, dataset=dataset)

        # Sanity-check the permission row is actually persisted.
        permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
        assert granted_user.id in permissions

    def test_check_dataset_operator_permission_partial_members_without_permission_error(
        self, db_session_with_containers
    ):
        """An operator with no permission row is rejected on a partial-team dataset."""
        factory = DatasetPermissionTestDataFactory
        owner, tenant = factory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        outsider, _ = factory.create_account_with_tenant(role=TenantAccountRole.NORMAL, tenant=tenant)
        dataset = factory.create_dataset(
            tenant.id,
            owner.id,
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
        )

        with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
            DatasetService.check_dataset_operator_permission(user=outsider, dataset=dataset)

View File

@@ -0,0 +1,233 @@
import datetime
import json
import uuid
from decimal import Decimal
import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import (
App,
AppAnnotationHitHistory,
Conversation,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats
class TestAppMessageExportServiceIntegration:
    """Container-backed integration tests for AppMessageExportService over real SQL rows."""

    @pytest.fixture(autouse=True)
    def cleanup_database(self, db_session_with_containers: Session):
        # Run the test first, then wipe every table this module touches.
        yield
        # Deletion proceeds from dependent tables toward parents
        # (message children -> Message -> Conversation -> App -> tenant/account),
        # so referencing rows are removed before the rows they point to.
        db_session_with_containers.query(DatasetRetrieverResource).delete()
        db_session_with_containers.query(AppAnnotationHitHistory).delete()
        db_session_with_containers.query(SavedMessage).delete()
        db_session_with_containers.query(MessageFile).delete()
        db_session_with_containers.query(MessageAgentThought).delete()
        db_session_with_containers.query(MessageChain).delete()
        db_session_with_containers.query(MessageAnnotation).delete()
        db_session_with_containers.query(MessageFeedback).delete()
        db_session_with_containers.query(Message).delete()
        db_session_with_containers.query(Conversation).delete()
        db_session_with_containers.query(App).delete()
        db_session_with_containers.query(TenantAccountJoin).delete()
        db_session_with_containers.query(Tenant).delete()
        db_session_with_containers.query(Account).delete()
        db_session_with_containers.commit()

    @staticmethod
    def _create_app_context(session: Session) -> tuple[App, Conversation]:
        """Persist an account/tenant/app/conversation chain and return the app and conversation."""
        account = Account(
            email=f"test-{uuid.uuid4()}@example.com",
            name="tester",
            interface_language="en-US",
            status="active",
        )
        session.add(account)
        session.flush()
        tenant = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal")
        session.add(tenant)
        session.flush()
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        session.add(join)
        session.flush()
        app = App(
            tenant_id=tenant.id,
            name="export-app",
            description="integration test app",
            mode="chat",
            enable_site=True,
            enable_api=True,
            api_rpm=60,
            api_rph=3600,
            is_demo=False,
            is_public=False,
            created_by=account.id,
            updated_by=account.id,
        )
        session.add(app)
        session.flush()
        conversation = Conversation(
            app_id=app.id,
            app_model_config_id=str(uuid.uuid4()),
            model_provider="openai",
            model_id="gpt-4o-mini",
            mode="chat",
            name="conv",
            inputs={"seed": 1},
            status="normal",
            from_source="api",
            from_end_user_id=str(uuid.uuid4()),
        )
        session.add(conversation)
        # Single commit at the end persists the whole chain built via flush().
        session.commit()
        return app, conversation

    @staticmethod
    def _create_message(
        session: Session,
        app: App,
        conversation: Conversation,
        created_at: datetime.datetime,
        *,
        query: str,
        answer: str,
        inputs: dict,
        message_metadata: str | None,
    ) -> Message:
        """Persist one Message in *conversation* with fixed token/price fields.

        Only flushed (not committed) so the caller controls transaction boundaries.
        """
        message = Message(
            app_id=app.id,
            conversation_id=conversation.id,
            model_provider="openai",
            model_id="gpt-4o-mini",
            inputs=inputs,
            query=query,
            answer=answer,
            message=[{"role": "assistant", "content": answer}],
            message_tokens=10,
            message_unit_price=Decimal("0.001"),
            answer_tokens=20,
            answer_unit_price=Decimal("0.002"),
            total_price=Decimal("0.003"),
            currency="USD",
            message_metadata=message_metadata,
            from_source="api",
            from_end_user_id=conversation.from_end_user_id,
            created_at=created_at,
        )
        session.add(message)
        session.flush()
        return message

    def test_iter_records_with_stats(self, db_session_with_containers: Session):
        # Arrange: two messages one minute apart; the first carries retriever
        # metadata and user feedback, the second has neither.
        app, conversation = self._create_app_context(db_session_with_containers)
        first_inputs = {
            "plain": "v1",
            "nested": {"a": 1, "b": [1, {"x": True}]},
            "list": ["x", 2, {"y": "z"}],
        }
        second_inputs = {"other": "value", "items": [1, 2, 3]}
        base_time = datetime.datetime(2026, 2, 25, 10, 0, 0)
        first_message = self._create_message(
            db_session_with_containers,
            app,
            conversation,
            created_at=base_time,
            query="q1",
            answer="a1",
            inputs=first_inputs,
            message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}),
        )
        second_message = self._create_message(
            db_session_with_containers,
            app,
            conversation,
            created_at=base_time + datetime.timedelta(minutes=1),
            query="q2",
            answer="a2",
            inputs=second_inputs,
            message_metadata=None,
        )
        # Two user feedbacks on the first message plus one admin feedback that
        # the export is expected to filter out.
        user_feedback_1 = MessageFeedback(
            app_id=app.id,
            conversation_id=conversation.id,
            message_id=first_message.id,
            rating="like",
            from_source="user",
            content="first",
            from_end_user_id=conversation.from_end_user_id,
        )
        user_feedback_2 = MessageFeedback(
            app_id=app.id,
            conversation_id=conversation.id,
            message_id=first_message.id,
            rating="dislike",
            from_source="user",
            content="second",
            from_end_user_id=conversation.from_end_user_id,
        )
        admin_feedback = MessageFeedback(
            app_id=app.id,
            conversation_id=conversation.id,
            message_id=first_message.id,
            rating="like",
            from_source="admin",
            content="should-be-filtered",
            from_account_id=str(uuid.uuid4()),
        )
        db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback])
        # created_at is assigned after add_all to pin a deterministic ordering.
        user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2)
        user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3)
        admin_feedback.created_at = base_time + datetime.timedelta(minutes=4)
        db_session_with_containers.commit()
        # batch_size=1 forces one batch per message; dry_run avoids writing the file.
        service = AppMessageExportService(
            app_id=app.id,
            start_from=base_time - datetime.timedelta(minutes=1),
            end_before=base_time + datetime.timedelta(minutes=10),
            filename="unused",
            batch_size=1,
            dry_run=True,
        )
        stats = AppMessageExportStats()
        # Act: drive the private record iterator directly and finalize the stats.
        records = list(service._iter_records_with_stats(stats))
        service._finalize_stats(stats)
        # Assert: chronological order, inputs round-trip, retriever resources
        # parsed from metadata, and only user-sourced feedback included in order.
        assert len(records) == 2
        assert records[0].message_id == first_message.id
        assert records[1].message_id == second_message.id
        assert records[0].inputs == first_inputs
        assert records[1].inputs == second_inputs
        assert records[0].retriever_resources == [{"dataset_id": "ds-1"}]
        assert records[1].retriever_resources == []
        assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"]
        assert [feedback.content for feedback in records[0].feedback] == ["first", "second"]
        assert records[1].feedback == []
        # batch_size=1 with two messages -> two batches.
        assert stats.batches == 2
        assert stats.total_messages == 2
        assert stats.messages_with_feedback == 1
        assert stats.total_feedbacks == 2

View File

@@ -322,11 +322,14 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
task_dispatch_spy.delay.assert_called_once_with(
tenant_id=next_task["tenant_id"],
dataset_id=next_task["dataset_id"],
document_ids=next_task["document_ids"],
)
# apply_async is used by implementation; assert it was called once with expected kwargs
assert task_dispatch_spy.apply_async.call_count == 1
call_kwargs = task_dispatch_spy.apply_async.call_args.kwargs.get("kwargs", {})
assert call_kwargs == {
"tenant_id": next_task["tenant_id"],
"dataset_id": next_task["dataset_id"],
"document_ids": next_task["document_ids"],
}
set_waiting_spy.assert_called_once()
delete_key_spy.assert_not_called()
@@ -352,7 +355,7 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
task_dispatch_spy.delay.assert_not_called()
task_dispatch_spy.apply_async.assert_not_called()
delete_key_spy.assert_called_once()
def test_validation_failure_sets_error_status_when_vector_space_at_limit(
@@ -447,7 +450,7 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
task_dispatch_spy.delay.assert_called_once()
task_dispatch_spy.apply_async.assert_called_once()
def test_sessions_close_on_successful_indexing(
self,
@@ -534,7 +537,7 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
assert task_dispatch_spy.delay.call_count == concurrency_limit
assert task_dispatch_spy.apply_async.call_count == concurrency_limit
assert set_waiting_spy.call_count == concurrency_limit
def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies):
@@ -565,9 +568,10 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
assert task_dispatch_spy.delay.call_count == 3
assert task_dispatch_spy.apply_async.call_count == 3
for index, expected_task in enumerate(ordered_tasks):
assert task_dispatch_spy.delay.call_args_list[index].kwargs["document_ids"] == expected_task["document_ids"]
call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {})
assert call_kwargs.get("document_ids") == expected_task["document_ids"]
def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies):
"""Skip limit checks when billing feature is disabled."""

View File

@@ -762,11 +762,12 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify task function was called for each waiting task
assert mock_task_func.delay.call_count == 1
assert mock_task_func.apply_async.call_count == 1
# Verify correct parameters for each call
calls = mock_task_func.delay.call_args_list
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
calls = mock_task_func.apply_async.call_args_list
sent_kwargs = calls[0][1]["kwargs"]
assert sent_kwargs == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
# Verify queue is empty after processing (tasks were pulled)
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
@@ -830,11 +831,15 @@ class TestDocumentIndexingTasks:
assert updated_document.processing_started_at is not None
# Verify waiting task was still processed despite core processing error
mock_task_func.delay.assert_called_once()
mock_task_func.apply_async.assert_called_once()
# Verify correct parameters for the call
call = mock_task_func.delay.call_args
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
call = mock_task_func.apply_async.call_args
assert call[1]["kwargs"] == {
"tenant_id": tenant_id,
"dataset_id": dataset_id,
"document_ids": ["waiting-doc-1"],
}
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
@@ -896,9 +901,13 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify only tenant1's waiting task was processed
mock_task_func.delay.assert_called_once()
call = mock_task_func.delay.call_args
assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
mock_task_func.apply_async.assert_called_once()
call = mock_task_func.apply_async.call_args
assert call[1]["kwargs"] == {
"tenant_id": tenant1_id,
"dataset_id": dataset1_id,
"document_ids": ["tenant1-doc-1"],
}
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)

View File

@@ -1,6 +1,6 @@
import json
import uuid
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
@@ -388,8 +388,10 @@ class TestRagPipelineRunTasks:
# Set the task key to indicate there are waiting tasks (legacy behavior)
redis_client.set(legacy_task_key, 1, ex=60 * 60)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the priority task with new code but legacy queue data
rag_pipeline_run_task(file_id, tenant.id)
@@ -398,13 +400,14 @@ class TestRagPipelineRunTasks:
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed, pull 1 task a time by default
assert mock_delay.call_count == 1
# Verify waiting tasks were processed via group, pull 1 task a time by default
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
assert first_kwargs.get("tenant_id") == tenant.id
# Verify that new code can process legacy queue entries
# The new TenantIsolatedTaskQueue should be able to read from the legacy format
@@ -446,8 +449,10 @@ class TestRagPipelineRunTasks:
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
queue.push_tasks(waiting_file_ids)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the regular task
rag_pipeline_run_task(file_id, tenant.id)
@@ -456,13 +461,14 @@ class TestRagPipelineRunTasks:
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed, pull 1 task a time by default
assert mock_delay.call_count == 1
# Verify waiting tasks were processed via group.apply_async
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
assert first_kwargs.get("tenant_id") == tenant.id
# Verify queue still has remaining tasks (only 1 was pulled)
remaining_tasks = queue.pull_tasks(count=10)
@@ -557,8 +563,10 @@ class TestRagPipelineRunTasks:
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the regular task (should not raise exception)
rag_pipeline_run_task(file_id, tenant.id)
@@ -569,12 +577,13 @@ class TestRagPipelineRunTasks:
assert mock_pipeline_generator.call_count == 1
# Verify waiting task was still processed despite core processing error
mock_delay.assert_called_once()
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert first_kwargs.get("tenant_id") == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
@@ -684,8 +693,10 @@ class TestRagPipelineRunTasks:
queue1.push_tasks([waiting_file_id1])
queue2.push_tasks([waiting_file_id2])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the regular task for tenant1 only
rag_pipeline_run_task(file_id1, tenant1.id)
@@ -694,11 +705,12 @@ class TestRagPipelineRunTasks:
assert mock_file_service["delete_file"].call_count == 1
assert mock_pipeline_generator.call_count == 1
# Verify only tenant1's waiting task was processed
mock_delay.assert_called_once()
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
assert call_kwargs.get("tenant_id") == tenant1.id
# Verify only tenant1's waiting task was processed (via group)
assert mock_group.return_value.apply_async.called
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
assert first_kwargs.get("tenant_id") == tenant1.id
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)
@@ -913,8 +925,10 @@ class TestRagPipelineRunTasks:
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act & Assert: Execute the regular task (should raise Exception)
with pytest.raises(Exception, match="File not found"):
rag_pipeline_run_task(file_id, tenant.id)
@@ -924,12 +938,13 @@ class TestRagPipelineRunTasks:
mock_pipeline_generator.assert_not_called()
# Verify waiting task was still processed despite file error
mock_delay.assert_called_once()
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert first_kwargs.get("tenant_id") == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)

View File

@@ -105,18 +105,26 @@ def app_model(
class MockCeleryGroup:
"""Mock for celery group() function that collects dispatched tasks."""
"""Mock for celery group() function that collects dispatched tasks.
Matches the Celery group API loosely, accepting arbitrary kwargs on apply_async
(e.g. producer) so production code can pass broker-related options without
breaking tests.
"""
def __init__(self) -> None:
self.collected: list[dict[str, Any]] = []
self._applied = False
self.last_apply_async_kwargs: dict[str, Any] | None = None
def __call__(self, items: Any) -> MockCeleryGroup:
self.collected = list(items)
return self
def apply_async(self) -> None:
def apply_async(self, **kwargs: Any) -> None:
# Accept arbitrary kwargs like producer to be compatible with Celery
self._applied = True
self.last_apply_async_kwargs = kwargs
@property
def applied(self) -> bool:

View File

@@ -0,0 +1,181 @@
import datetime
import re
from unittest.mock import MagicMock, patch
import click
import pytest
from commands import clean_expired_messages
def _mock_service() -> MagicMock:
service = MagicMock()
service.run.return_value = {
"batches": 1,
"total_messages": 10,
"filtered_messages": 5,
"total_deleted": 5,
}
return service
def test_absolute_mode_calls_from_time_range():
    """An explicit absolute window must route through ``from_time_range`` only."""
    policy = object()
    service = _mock_service()
    window_start = datetime.datetime(2024, 1, 1, 0, 0, 0)
    window_end = datetime.datetime(2024, 2, 1, 0, 0, 0)
    with (
        patch("commands.create_message_clean_policy", return_value=policy),
        patch("commands.MessagesCleanService.from_time_range", return_value=service) as from_time_range_mock,
        patch("commands.MessagesCleanService.from_days") as from_days_mock,
    ):
        clean_expired_messages.callback(
            batch_size=200,
            graceful_period=21,
            start_from=window_start,
            end_before=window_end,
            from_days_ago=None,
            before_days=None,
            dry_run=True,
        )
    from_time_range_mock.assert_called_once_with(
        policy=policy,
        start_from=window_start,
        end_before=window_end,
        batch_size=200,
        dry_run=True,
    )
    from_days_mock.assert_not_called()
def test_relative_mode_before_days_only_calls_from_days():
    """Only --before-days supplied: the relative path must use ``from_days``."""
    policy = object()
    service = _mock_service()
    with (
        patch("commands.create_message_clean_policy", return_value=policy),
        patch("commands.MessagesCleanService.from_days", return_value=service) as from_days_mock,
        patch("commands.MessagesCleanService.from_time_range") as from_time_range_mock,
    ):
        clean_expired_messages.callback(
            batch_size=500,
            graceful_period=14,
            start_from=None,
            end_before=None,
            from_days_ago=None,
            before_days=30,
            dry_run=False,
        )
    from_days_mock.assert_called_once_with(
        policy=policy,
        days=30,
        batch_size=500,
        dry_run=False,
    )
    from_time_range_mock.assert_not_called()
def test_relative_mode_with_from_days_ago_calls_from_time_range():
    """--from-days-ago plus --before-days compute an absolute window from 'now'."""
    policy = object()
    service = _mock_service()
    frozen_now = datetime.datetime(2024, 8, 20, 12, 0, 0)
    with (
        patch("commands.create_message_clean_policy", return_value=policy),
        patch("commands.MessagesCleanService.from_time_range", return_value=service) as from_time_range_mock,
        patch("commands.MessagesCleanService.from_days") as from_days_mock,
        patch("commands.naive_utc_now", return_value=frozen_now),
    ):
        clean_expired_messages.callback(
            batch_size=1000,
            graceful_period=21,
            start_from=None,
            end_before=None,
            from_days_ago=60,
            before_days=30,
            dry_run=False,
        )
    from_time_range_mock.assert_called_once_with(
        policy=policy,
        start_from=frozen_now - datetime.timedelta(days=60),
        end_before=frozen_now - datetime.timedelta(days=30),
        batch_size=1000,
        dry_run=False,
    )
    from_days_mock.assert_not_called()
@pytest.mark.parametrize(
    ("kwargs", "message"),
    [
        (
            {
                "start_from": datetime.datetime(2024, 1, 1),
                "end_before": datetime.datetime(2024, 2, 1),
                "from_days_ago": None,
                "before_days": 30,
            },
            "mutually exclusive",
        ),
        (
            {
                "start_from": datetime.datetime(2024, 1, 1),
                "end_before": None,
                "from_days_ago": None,
                "before_days": None,
            },
            "Both --start-from and --end-before are required",
        ),
        (
            {
                "start_from": None,
                "end_before": None,
                "from_days_ago": 10,
                "before_days": None,
            },
            "--from-days-ago must be used together with --before-days",
        ),
        (
            {
                "start_from": None,
                "end_before": None,
                "from_days_ago": None,
                "before_days": -1,
            },
            "--before-days must be >= 0",
        ),
        (
            {
                "start_from": None,
                "end_before": None,
                "from_days_ago": 30,
                "before_days": 30,
            },
            "--from-days-ago must be greater than --before-days",
        ),
        (
            {
                "start_from": None,
                "end_before": None,
                "from_days_ago": None,
                "before_days": None,
            },
            "You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])",
        ),
    ],
)
def test_invalid_inputs_raise_usage_error(kwargs: dict, message: str):
    """Every invalid flag combination must abort with a click.UsageError."""
    with pytest.raises(click.UsageError, match=re.escape(message)):
        # The parametrized dict supplies exactly the four mode-selection options.
        clean_expired_messages.callback(
            batch_size=1000,
            graceful_period=21,
            dry_run=False,
            **kwargs,
        )

View File

@@ -32,11 +32,6 @@ os.environ.setdefault("OPENDAL_SCHEME", "fs")
os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")
os.environ.setdefault("STORAGE_TYPE", "opendal")
# Add the API directory to Python path to ensure proper imports
import sys
sys.path.insert(0, PROJECT_DIR)
from core.db.session_factory import configure_session_factory, session_factory
from extensions import ext_redis

View File

@@ -0,0 +1,70 @@
from controllers.common.errors import (
BlockedFileExtensionError,
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
RemoteFileUploadError,
TooManyFilesError,
UnsupportedFileTypeError,
)
class TestFilenameNotExistsError:
    def test_defaults(self):
        """A missing filename maps to HTTP 400 with a fixed description."""
        err = FilenameNotExistsError()
        assert err.code == 400
        assert err.description == "The specified filename does not exist."
class TestRemoteFileUploadError:
    def test_defaults(self):
        """Remote upload failures map to HTTP 400."""
        err = RemoteFileUploadError()
        assert err.code == 400
        assert err.description == "Error uploading remote file."
class TestFileTooLargeError:
    def test_defaults(self):
        """Oversized uploads map to HTTP 413 with a templated description."""
        err = FileTooLargeError()
        assert err.code == 413
        assert err.error_code == "file_too_large"
        assert err.description == "File size exceeded. {message}"
class TestUnsupportedFileTypeError:
    def test_defaults(self):
        """Unsupported file types map to HTTP 415."""
        err = UnsupportedFileTypeError()
        assert err.code == 415
        assert err.error_code == "unsupported_file_type"
        assert err.description == "File type not allowed."
class TestBlockedFileExtensionError:
    def test_defaults(self):
        """Security-blocked extensions map to HTTP 400."""
        err = BlockedFileExtensionError()
        assert err.code == 400
        assert err.error_code == "file_extension_blocked"
        assert err.description == "The file extension is blocked for security reasons."
class TestTooManyFilesError:
    def test_defaults(self):
        """Multi-file uploads map to HTTP 400."""
        err = TooManyFilesError()
        assert err.code == 400
        assert err.error_code == "too_many_files"
        assert err.description == "Only one file is allowed."
class TestNoFileUploadedError:
    def test_defaults(self):
        """A request with no file maps to HTTP 400."""
        err = NoFileUploadedError()
        assert err.code == 400
        assert err.error_code == "no_file_uploaded"
        assert err.description == "Please upload your file."

View File

@@ -1,22 +1,95 @@
from flask import Response
from controllers.common.file_response import enforce_download_for_html, is_html_content
from controllers.common.file_response import (
_normalize_mime_type,
enforce_download_for_html,
is_html_content,
)
class TestFileResponseHelpers:
def test_is_html_content_detects_mime_type(self):
class TestNormalizeMimeType:
def test_returns_empty_string_for_none(self):
assert _normalize_mime_type(None) == ""
def test_returns_empty_string_for_empty_string(self):
assert _normalize_mime_type("") == ""
def test_normalizes_mime_type(self):
assert _normalize_mime_type("Text/HTML; Charset=UTF-8") == "text/html"
class TestIsHtmlContent:
def test_detects_html_via_mime_type(self):
mime_type = "text/html; charset=UTF-8"
result = is_html_content(mime_type, filename="file.txt", extension="txt")
result = is_html_content(
mime_type=mime_type,
filename="file.txt",
extension="txt",
)
assert result is True
def test_is_html_content_detects_extension(self):
result = is_html_content("text/plain", filename="report.html", extension=None)
def test_detects_html_via_extension_argument(self):
result = is_html_content(
mime_type="text/plain",
filename=None,
extension="html",
)
assert result is True
def test_enforce_download_for_html_sets_headers(self):
def test_detects_html_via_filename_extension(self):
result = is_html_content(
mime_type="text/plain",
filename="report.html",
extension=None,
)
assert result is True
def test_returns_false_when_no_html_detected_anywhere(self):
"""
Missing negative test:
- MIME type is not HTML
- filename has no HTML extension
- extension argument is not HTML
"""
result = is_html_content(
mime_type="application/json",
filename="data.json",
extension="json",
)
assert result is False
def test_returns_false_when_all_inputs_are_none(self):
result = is_html_content(
mime_type=None,
filename=None,
extension=None,
)
assert result is False
class TestEnforceDownloadForHtml:
def test_sets_attachment_when_filename_missing(self):
response = Response("payload", mimetype="text/html")
updated = enforce_download_for_html(
response,
mime_type="text/html",
filename=None,
extension="html",
)
assert updated is True
assert response.headers["Content-Disposition"] == "attachment"
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["X-Content-Type-Options"] == "nosniff"
def test_sets_headers_when_filename_present(self):
response = Response("payload", mimetype="text/html")
updated = enforce_download_for_html(
@@ -27,11 +100,12 @@ class TestFileResponseHelpers:
)
assert updated is True
assert "attachment" in response.headers["Content-Disposition"]
assert response.headers["Content-Disposition"].startswith("attachment")
assert "unsafe.html" in response.headers["Content-Disposition"]
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["X-Content-Type-Options"] == "nosniff"
def test_enforce_download_for_html_no_change_for_non_html(self):
def test_does_not_modify_response_for_non_html_content(self):
response = Response("payload", mimetype="text/plain")
updated = enforce_download_for_html(

View File

@@ -0,0 +1,188 @@
from uuid import UUID
import httpx
import pytest
from controllers.common import helpers
from controllers.common.helpers import FileInfo, guess_file_info_from_response
def make_response(
    url="https://example.com/file.txt",
    headers=None,
    content=None,
):
    """Build a 200 httpx.Response with a GET request attached for URL-based guessing."""
    request = httpx.Request("GET", url)
    return httpx.Response(
        200,
        request=request,
        headers=headers or {},
        content=content or b"",
    )
class TestGuessFileInfoFromResponse:
    """Behavioral tests for deriving FileInfo from an httpx response."""

    def test_filename_from_url(self):
        resp = make_response(
            url="https://example.com/test.pdf",
            content=b"Hello World",
        )
        guessed = guess_file_info_from_response(resp)
        assert guessed.filename == "test.pdf"
        assert guessed.extension == ".pdf"
        assert guessed.mimetype == "application/pdf"

    def test_filename_from_content_disposition(self):
        resp = make_response(
            url="https://example.com/",
            headers={
                "Content-Disposition": "attachment; filename=myfile.csv",
                "Content-Type": "text/csv",
            },
            content=b"Hello World",
        )
        guessed = guess_file_info_from_response(resp)
        assert guessed.filename == "myfile.csv"
        assert guessed.extension == ".csv"
        assert guessed.mimetype == "text/csv"

    @pytest.mark.parametrize(
        ("magic_available", "expected_ext"),
        [
            (True, "txt"),
            (False, "bin"),
        ],
    )
    def test_generated_filename_when_missing(self, monkeypatch, magic_available, expected_ext):
        # With magic importable, content sniffing yields "txt"; without it, "bin".
        if magic_available:
            if helpers.magic is None:
                pytest.skip("python-magic is not installed, cannot run 'magic_available=True' test variant")
        else:
            monkeypatch.setattr(helpers, "magic", None)
        resp = make_response(
            url="https://example.com/",
            content=b"Hello World",
        )
        guessed = guess_file_info_from_response(resp)
        stem, suffix = guessed.filename.split(".")
        UUID(stem)  # the generated stem must parse as a UUID
        assert suffix == expected_ext

    def test_mimetype_from_header_when_unknown(self):
        resp = make_response(
            url="https://example.com/file.unknown",
            headers={"Content-Type": "application/json"},
            content=b'{"a": 1}',
        )
        assert guess_file_info_from_response(resp).mimetype == "application/json"

    def test_extension_added_when_missing(self):
        resp = make_response(
            url="https://example.com/image",
            headers={"Content-Type": "image/png"},
            content=b"fakepngdata",
        )
        guessed = guess_file_info_from_response(resp)
        assert guessed.extension == ".png"
        assert guessed.filename.endswith(".png")

    def test_content_length_used_as_size(self):
        resp = make_response(
            url="https://example.com/a.txt",
            headers={
                "Content-Length": "1234",
                "Content-Type": "text/plain",
            },
            content=b"a" * 1234,
        )
        assert guess_file_info_from_response(resp).size == 1234

    def test_size_minus_one_when_header_missing(self):
        resp = make_response(url="https://example.com/a.txt")
        assert guess_file_info_from_response(resp).size == -1

    def test_fallback_to_bin_extension(self):
        resp = make_response(
            url="https://example.com/download",
            headers={"Content-Type": "application/octet-stream"},
            content=b"\x00\x01\x02\x03",
        )
        guessed = guess_file_info_from_response(resp)
        assert guessed.extension == ".bin"
        assert guessed.filename.endswith(".bin")

    def test_return_type(self):
        assert isinstance(guess_file_info_from_response(make_response()), FileInfo)
class TestMagicImportWarnings:
    """Each platform should get a tailored install hint when `magic` fails to import."""

    @pytest.mark.parametrize(
        ("platform_name", "expected_message"),
        [
            ("Windows", "pip install python-magic-bin"),
            ("Darwin", "brew install libmagic"),
            ("Linux", "sudo apt-get install libmagic1"),
            ("Other", "install `libmagic`"),
        ],
    )
    def test_magic_import_warning_per_platform(
        self,
        monkeypatch,
        platform_name,
        expected_message,
    ):
        """Re-import helpers with `magic` blocked and assert the platform-specific hint."""
        import builtins
        import importlib

        # Force ImportError when "magic" is imported
        real_import = builtins.__import__

        def fake_import(name, *args, **kwargs):
            if name == "magic":
                raise ImportError("No module named magic")
            return real_import(name, *args, **kwargs)

        monkeypatch.setattr(builtins, "__import__", fake_import)
        monkeypatch.setattr("platform.system", lambda: platform_name)
        # Remove helpers so it imports fresh
        import sys

        original_helpers = sys.modules.get(helpers.__name__)
        sys.modules.pop(helpers.__name__, None)
        try:
            # Importing the module re-runs its top-level magic import, which warns.
            with pytest.warns(UserWarning, match="To use python-magic") as warning:
                imported_helpers = importlib.import_module(helpers.__name__)
            assert expected_message in str(warning[0].message)
        finally:
            # Restore the previously cached module; if helpers was not cached before,
            # the freshly imported copy is left in sys.modules — acceptable for tests.
            if original_helpers is not None:
                sys.modules[helpers.__name__] = original_helpers

View File

@@ -0,0 +1,189 @@
import sys
from enum import StrEnum
from unittest.mock import MagicMock, patch
import pytest
from flask_restx import Namespace
from pydantic import BaseModel
class UserModel(BaseModel):
    """Simple pydantic model used as a schema-registration fixture."""

    id: int
    name: str
class ProductModel(BaseModel):
    """Second pydantic fixture model, used to test multi-model registration."""

    id: int
    price: float
@pytest.fixture(autouse=True)
def mock_console_ns():
    """Mock the console_ns to avoid circular imports during test collection."""
    fake_ns = MagicMock(spec=Namespace)
    fake_ns.models = {}
    # Inject mock before importing schema module
    with patch.dict(sys.modules, {"controllers.console": MagicMock(console_ns=fake_ns)}):
        yield fake_ns
def test_default_ref_template_value():
    """The Swagger 2.0 ref template must point at #/definitions."""
    from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0

    assert DEFAULT_REF_TEMPLATE_SWAGGER_2_0 == "#/definitions/{model}"
def test_register_schema_model_calls_namespace_schema_model():
    """register_schema_model must register exactly one dict schema under the model name."""
    from controllers.common.schema import register_schema_model

    ns = MagicMock(spec=Namespace)
    register_schema_model(ns, UserModel)
    ns.schema_model.assert_called_once()
    registered_name, registered_schema = ns.schema_model.call_args.args
    assert registered_name == "UserModel"
    assert isinstance(registered_schema, dict)
    assert "properties" in registered_schema
def test_register_schema_model_passes_schema_from_pydantic():
    """The registered schema equals pydantic's JSON schema with the Swagger ref template."""
    from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_model

    ns = MagicMock(spec=Namespace)
    register_schema_model(ns, UserModel)
    produced = ns.schema_model.call_args.args[1]
    assert produced == UserModel.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
def test_register_schema_models_registers_multiple_models():
    """Both models are registered, preserving argument order."""
    from controllers.common.schema import register_schema_models

    ns = MagicMock(spec=Namespace)
    register_schema_models(ns, UserModel, ProductModel)
    assert ns.schema_model.call_count == 2
    assert [c.args[0] for c in ns.schema_model.call_args_list] == ["UserModel", "ProductModel"]
def test_register_schema_models_calls_register_schema_model(monkeypatch):
    """register_schema_models delegates once per model to register_schema_model."""
    from controllers.common.schema import register_schema_models

    ns = MagicMock(spec=Namespace)
    recorded = []

    def fake_register(namespace, model):
        recorded.append((namespace, model))

    monkeypatch.setattr(
        "controllers.common.schema.register_schema_model",
        fake_register,
    )
    register_schema_models(ns, UserModel, ProductModel)
    assert recorded == [
        (ns, UserModel),
        (ns, ProductModel),
    ]
class StatusEnum(StrEnum):
    """Sample string enum used to exercise enum schema registration."""

    ACTIVE = "active"
    INACTIVE = "inactive"
class PriorityEnum(StrEnum):
    """Second sample enum, used to test multi-enum registration ordering."""

    HIGH = "high"
    LOW = "low"
def test_get_or_create_model_returns_existing_model(mock_console_ns):
    """A cached model is returned without creating a new one."""
    from controllers.common.schema import get_or_create_model

    cached = MagicMock()
    mock_console_ns.models = {"TestModel": cached}
    assert get_or_create_model("TestModel", {"key": "value"}) == cached
    mock_console_ns.model.assert_not_called()
def test_get_or_create_model_creates_new_model_when_not_exists(mock_console_ns):
    """An unknown name triggers creation via console_ns.model."""
    from controllers.common.schema import get_or_create_model

    mock_console_ns.models = {}
    created = MagicMock()
    mock_console_ns.model.return_value = created
    spec = {"name": {"type": "string"}}
    assert get_or_create_model("NewModel", spec) == created
    mock_console_ns.model.assert_called_once_with("NewModel", spec)
def test_get_or_create_model_does_not_call_model_if_exists(mock_console_ns):
    """Creation is skipped entirely when the name is already registered."""
    from controllers.common.schema import get_or_create_model

    cached = MagicMock()
    mock_console_ns.models = {"ExistingModel": cached}
    assert get_or_create_model("ExistingModel", {"key": "value"}) == cached
    mock_console_ns.model.assert_not_called()
def test_register_enum_models_registers_single_enum():
    """A single StrEnum registers one dict schema under its class name."""
    from controllers.common.schema import register_enum_models

    ns = MagicMock(spec=Namespace)
    register_enum_models(ns, StatusEnum)
    ns.schema_model.assert_called_once()
    registered_name, registered_schema = ns.schema_model.call_args.args
    assert registered_name == "StatusEnum"
    assert isinstance(registered_schema, dict)
def test_register_enum_models_registers_multiple_enums():
    """Multiple enums register in argument order."""
    from controllers.common.schema import register_enum_models

    ns = MagicMock(spec=Namespace)
    register_enum_models(ns, StatusEnum, PriorityEnum)
    assert ns.schema_model.call_count == 2
    assert [c.args[0] for c in ns.schema_model.call_args_list] == ["StatusEnum", "PriorityEnum"]
def test_register_enum_models_uses_correct_ref_template():
    """The produced schema actually carries the enum's value constraints."""
    from controllers.common.schema import register_enum_models

    ns = MagicMock(spec=Namespace)
    register_enum_models(ns, StatusEnum)
    produced = ns.schema_model.call_args.args[1]
    # Verify the schema contains enum values
    assert "enum" in produced or "anyOf" in produced

View File

@@ -0,0 +1,85 @@
"""Shared fixtures for controllers.web unit tests."""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
import pytest
from flask import Flask
@pytest.fixture
def app() -> Flask:
    """Minimal Flask app for request contexts."""
    application = Flask(__name__)
    application.config["TESTING"] = True
    return application
class FakeSession:
    """Stand-in for db.session that returns pre-seeded objects by model class name."""

    def __init__(self, mapping: dict[str, Any] | None = None):
        # Seeded objects are looked up by the *class name* of the queried model.
        self._mapping: dict[str, Any] = mapping or {}
        self._model_name: str | None = None

    def query(self, model: type) -> FakeSession:
        """Remember which model was queried; chainable."""
        self._model_name = model.__name__
        return self

    def where(self, *_args: object, **_kwargs: object) -> FakeSession:
        """Ignore all filter criteria; chainable no-op."""
        return self

    def first(self) -> Any:
        """Return the seeded object for the last queried model, or None."""
        assert self._model_name is not None
        return self._mapping.get(self._model_name)
class FakeDB:
    """Minimal db stub exposing engine and session."""

    def __init__(self, session: FakeSession | None = None):
        # Fall back to an empty FakeSession only when none is supplied.
        self.session = session if session is not None else FakeSession()
        # Opaque placeholder; code under test only needs an attribute to exist.
        self.engine = object()
def make_app_model(
    *,
    app_id: str = "app-1",
    tenant_id: str = "tenant-1",
    mode: str = "chat",
    enable_site: bool = True,
    status: str = "normal",
) -> SimpleNamespace:
    """Build a fake App model (with an attached tenant) using common defaults."""
    return SimpleNamespace(
        id=app_id,
        tenant_id=tenant_id,
        tenant=SimpleNamespace(
            id=tenant_id,
            status="normal",
            plan="basic",
            custom_config_dict={},
        ),
        mode=mode,
        enable_site=enable_site,
        status=status,
        workflow=None,
        app_model_config=None,
    )
def make_end_user(
    *,
    user_id: str = "end-user-1",
    session_id: str = "session-1",
    external_user_id: str = "ext-user-1",
) -> SimpleNamespace:
    """Build a fake EndUser model with common defaults."""
    attrs = {
        "id": user_id,
        "session_id": session_id,
        "external_user_id": external_user_id,
    }
    return SimpleNamespace(**attrs)

View File

@@ -0,0 +1,165 @@
"""Unit tests for controllers.web.app endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.app import AppAccessMode, AppMeta, AppParameterApi, AppWebAuthPermission
from controllers.web.error import AppUnavailableError
# ---------------------------------------------------------------------------
# AppParameterApi
# ---------------------------------------------------------------------------
class TestAppParameterApi:
    """Unit tests for the /parameters endpooint resource's mode dispatch."""

    def test_advanced_chat_mode_uses_workflow(self, app: Flask) -> None:
        feature_conf = {"opening_statement": "Hello"}
        wf = SimpleNamespace(
            features_dict=feature_conf,
            user_input_form=lambda to_old_structure=False: [],
        )
        fake_app = SimpleNamespace(mode="advanced-chat", workflow=wf)
        with (
            app.test_request_context("/parameters"),
            patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as params_mock,
            patch("controllers.web.app.fields.Parameters") as fields_mock,
        ):
            fields_mock.model_validate.return_value.model_dump.return_value = {"result": "ok"}
            outcome = AppParameterApi().get(fake_app, SimpleNamespace())
        params_mock.assert_called_once_with(features_dict=feature_conf, user_input_form=[])
        assert outcome == {"result": "ok"}

    def test_workflow_mode_uses_workflow(self, app: Flask) -> None:
        feature_conf = {}
        wf = SimpleNamespace(
            features_dict=feature_conf,
            user_input_form=lambda to_old_structure=False: [{"var": "x"}],
        )
        fake_app = SimpleNamespace(mode="workflow", workflow=wf)
        with (
            app.test_request_context("/parameters"),
            patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as params_mock,
            patch("controllers.web.app.fields.Parameters") as fields_mock,
        ):
            fields_mock.model_validate.return_value.model_dump.return_value = {}
            AppParameterApi().get(fake_app, SimpleNamespace())
        params_mock.assert_called_once_with(features_dict=feature_conf, user_input_form=[{"var": "x"}])

    def test_advanced_chat_mode_no_workflow_raises(self, app: Flask) -> None:
        fake_app = SimpleNamespace(mode="advanced-chat", workflow=None)
        with app.test_request_context("/parameters"), pytest.raises(AppUnavailableError):
            AppParameterApi().get(fake_app, SimpleNamespace())

    def test_standard_mode_uses_app_model_config(self, app: Flask) -> None:
        conf = SimpleNamespace(to_dict=lambda: {"user_input_form": [{"var": "y"}], "key": "val"})
        fake_app = SimpleNamespace(mode="chat", app_model_config=conf)
        with (
            app.test_request_context("/parameters"),
            patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as params_mock,
            patch("controllers.web.app.fields.Parameters") as fields_mock,
        ):
            fields_mock.model_validate.return_value.model_dump.return_value = {}
            AppParameterApi().get(fake_app, SimpleNamespace())
        assert params_mock.call_args.kwargs["user_input_form"] == [{"var": "y"}]

    def test_standard_mode_no_config_raises(self, app: Flask) -> None:
        fake_app = SimpleNamespace(mode="chat", app_model_config=None)
        with app.test_request_context("/parameters"), pytest.raises(AppUnavailableError):
            AppParameterApi().get(fake_app, SimpleNamespace())
# ---------------------------------------------------------------------------
# AppMeta
# ---------------------------------------------------------------------------
class TestAppMeta:
    """Tests for the /meta endpoint handler."""

    @patch("controllers.web.app.AppService")
    def test_get_returns_meta(self, mock_service_cls: MagicMock, app: Flask) -> None:
        """The endpoint returns whatever AppService.get_app_meta yields."""
        meta = {"tool_icons": {}}
        mock_service_cls.return_value.get_app_meta.return_value = meta
        with app.test_request_context("/meta"):
            response = AppMeta().get(SimpleNamespace(id="app-1"), SimpleNamespace())
        assert response == meta
# ---------------------------------------------------------------------------
# AppAccessMode
# ---------------------------------------------------------------------------
class TestAppAccessMode:
    """Tests for the /webapp/access-mode endpoint handler."""

    @patch("controllers.web.app.FeatureService.get_system_features")
    def test_returns_public_when_webapp_auth_disabled(self, mock_features: MagicMock, app: Flask) -> None:
        """With webapp auth disabled, every app is reported as public."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
        with app.test_request_context("/webapp/access-mode?appId=app-1"):
            response = AppAccessMode().get()
        assert response == {"accessMode": "public"}

    @patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
    @patch("controllers.web.app.FeatureService.get_system_features")
    def test_returns_access_mode_with_app_id(
        self, mock_features: MagicMock, mock_access: MagicMock, app: Flask
    ) -> None:
        """An explicit appId is looked up directly."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
        mock_access.return_value = SimpleNamespace(access_mode="internal")
        with app.test_request_context("/webapp/access-mode?appId=app-1"):
            response = AppAccessMode().get()
        assert response == {"accessMode": "internal"}
        mock_access.assert_called_once_with("app-1")

    @patch("controllers.web.app.AppService.get_app_id_by_code", return_value="resolved-id")
    @patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
    @patch("controllers.web.app.FeatureService.get_system_features")
    def test_resolves_app_code_to_id(
        self, mock_features: MagicMock, mock_access: MagicMock, mock_resolve: MagicMock, app: Flask
    ) -> None:
        """An appCode is first resolved to an app id, then looked up."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
        mock_access.return_value = SimpleNamespace(access_mode="external")
        with app.test_request_context("/webapp/access-mode?appCode=code1"):
            response = AppAccessMode().get()
        mock_resolve.assert_called_once_with("code1")
        mock_access.assert_called_once_with("resolved-id")
        assert response == {"accessMode": "external"}

    @patch("controllers.web.app.FeatureService.get_system_features")
    def test_raises_when_no_app_id_or_code(self, mock_features: MagicMock, app: Flask) -> None:
        """Omitting both appId and appCode is rejected with ValueError."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
        with app.test_request_context("/webapp/access-mode"), pytest.raises(ValueError, match="appId or appCode"):
            AppAccessMode().get()
# ---------------------------------------------------------------------------
# AppWebAuthPermission
# ---------------------------------------------------------------------------
class TestAppWebAuthPermission:
    """Tests for the /webapp/permission endpoint handler."""

    @patch("controllers.web.app.WebAppAuthService.is_app_require_permission_check", return_value=False)
    def test_returns_true_when_no_permission_check_required(self, mock_check: MagicMock, app: Flask) -> None:
        """Apps exempt from permission checks always grant access."""
        request_headers = {"X-App-Code": "code1"}
        with app.test_request_context("/webapp/permission?appId=app-1", headers=request_headers):
            response = AppWebAuthPermission().get()
        assert response == {"result": True}

    def test_raises_when_missing_app_id(self, app: Flask) -> None:
        """A request without appId is rejected with ValueError."""
        request_headers = {"X-App-Code": "code1"}
        with (
            app.test_request_context("/webapp/permission", headers=request_headers),
            pytest.raises(ValueError, match="appId"),
        ):
            AppWebAuthPermission().get()

View File

@@ -0,0 +1,135 @@
"""Unit tests for controllers.web.audio endpoints."""
from __future__ import annotations
from io import BytesIO
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.audio import AudioApi, TextApi
from controllers.web.error import (
AudioTooLargeError,
CompletionRequestError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from dify_graph.model_runtime.errors.invoke import InvokeError
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
def _app_model() -> SimpleNamespace:
    """Minimal stand-in for an App model in chat mode."""
    return SimpleNamespace(**{"id": "app-1", "mode": "chat"})


def _end_user() -> SimpleNamespace:
    """Minimal stand-in for an EndUser record."""
    return SimpleNamespace(**{"id": "eu-1", "external_user_id": "ext-1"})
# ---------------------------------------------------------------------------
# AudioApi (audio-to-text)
# ---------------------------------------------------------------------------
class TestAudioApi:
    """Tests for the audio-to-text endpoint: the happy path plus the mapping of each service/provider error onto its HTTP error class."""
    @patch("controllers.web.audio.AudioService.transcript_asr", return_value={"text": "hello"})
    def test_happy_path(self, mock_asr: MagicMock, app: Flask) -> None:
        """A valid multipart upload returns the transcription result unchanged."""
        app.config["RESTX_MASK_HEADER"] = "X-Fields"
        data = {"file": (BytesIO(b"fake-audio"), "test.mp3")}
        with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
            result = AudioApi().post(_app_model(), _end_user())
            assert result == {"text": "hello"}
    @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=NoAudioUploadedServiceError())
    def test_no_audio_uploaded(self, mock_asr: MagicMock, app: Flask) -> None:
        """NoAudioUploadedServiceError maps to the HTTP NoAudioUploadedError."""
        data = {"file": (BytesIO(b""), "empty.mp3")}
        with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
            with pytest.raises(NoAudioUploadedError):
                AudioApi().post(_app_model(), _end_user())
    @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=AudioTooLargeServiceError("too big"))
    def test_audio_too_large(self, mock_asr: MagicMock, app: Flask) -> None:
        """AudioTooLargeServiceError maps to the HTTP AudioTooLargeError."""
        data = {"file": (BytesIO(b"big"), "big.mp3")}
        with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
            with pytest.raises(AudioTooLargeError):
                AudioApi().post(_app_model(), _end_user())
    @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=UnsupportedAudioTypeServiceError())
    def test_unsupported_type(self, mock_asr: MagicMock, app: Flask) -> None:
        """UnsupportedAudioTypeServiceError maps to the HTTP UnsupportedAudioTypeError."""
        data = {"file": (BytesIO(b"bad"), "bad.xyz")}
        with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
            with pytest.raises(UnsupportedAudioTypeError):
                AudioApi().post(_app_model(), _end_user())
    @patch(
        "controllers.web.audio.AudioService.transcript_asr",
        side_effect=ProviderNotSupportSpeechToTextServiceError(),
    )
    def test_provider_not_support(self, mock_asr: MagicMock, app: Flask) -> None:
        """A provider without speech-to-text maps to ProviderNotSupportSpeechToTextError."""
        data = {"file": (BytesIO(b"x"), "x.mp3")}
        with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
            with pytest.raises(ProviderNotSupportSpeechToTextError):
                AudioApi().post(_app_model(), _end_user())
    @patch(
        "controllers.web.audio.AudioService.transcript_asr",
        side_effect=ProviderTokenNotInitError(description="no token"),
    )
    def test_provider_not_init(self, mock_asr: MagicMock, app: Flask) -> None:
        """An uninitialised provider token maps to ProviderNotInitializeError."""
        data = {"file": (BytesIO(b"x"), "x.mp3")}
        with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
            with pytest.raises(ProviderNotInitializeError):
                AudioApi().post(_app_model(), _end_user())
    @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=QuotaExceededError())
    def test_quota_exceeded(self, mock_asr: MagicMock, app: Flask) -> None:
        """Quota exhaustion maps to ProviderQuotaExceededError."""
        data = {"file": (BytesIO(b"x"), "x.mp3")}
        with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
            with pytest.raises(ProviderQuotaExceededError):
                AudioApi().post(_app_model(), _end_user())
    @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=ModelCurrentlyNotSupportError())
    def test_model_not_support(self, mock_asr: MagicMock, app: Flask) -> None:
        """An unsupported model maps to ProviderModelCurrentlyNotSupportError."""
        data = {"file": (BytesIO(b"x"), "x.mp3")}
        with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
            with pytest.raises(ProviderModelCurrentlyNotSupportError):
                AudioApi().post(_app_model(), _end_user())
# ---------------------------------------------------------------------------
# TextApi (text-to-audio)
# ---------------------------------------------------------------------------
class TestTextApi:
    """Tests for the text-to-audio endpoint handler."""

    @patch("controllers.web.audio.AudioService.transcript_tts", return_value="audio-bytes")
    @patch("controllers.web.audio.web_ns")
    def test_happy_path(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None:
        """A valid payload is synthesised and the raw service result returned."""
        mock_ns.payload = {"text": "hello", "voice": "alloy"}
        with app.test_request_context("/text-to-audio", method="POST"):
            response = TextApi().post(_app_model(), _end_user())
        assert response == "audio-bytes"
        mock_tts.assert_called_once()

    @patch(
        "controllers.web.audio.AudioService.transcript_tts",
        side_effect=InvokeError(description="invoke failed"),
    )
    @patch("controllers.web.audio.web_ns")
    def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None:
        """Model-runtime invoke failures surface as CompletionRequestError."""
        mock_ns.payload = {"text": "hello"}
        with app.test_request_context("/text-to-audio", method="POST"), pytest.raises(CompletionRequestError):
            TextApi().post(_app_model(), _end_user())

View File

@@ -0,0 +1,161 @@
"""Unit tests for controllers.web.completion endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
from controllers.web.error import (
CompletionRequestError,
NotChatAppError,
NotCompletionAppError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from dify_graph.model_runtime.errors.invoke import InvokeError
def _completion_app() -> SimpleNamespace:
    """Minimal stand-in for an App model in completion mode."""
    return SimpleNamespace(**{"id": "app-1", "mode": "completion"})


def _chat_app() -> SimpleNamespace:
    """Minimal stand-in for an App model in chat mode."""
    return SimpleNamespace(**{"id": "app-1", "mode": "chat"})


def _end_user() -> SimpleNamespace:
    """Minimal stand-in for an EndUser record."""
    return SimpleNamespace(**{"id": "eu-1"})
# ---------------------------------------------------------------------------
# CompletionApi
# ---------------------------------------------------------------------------
class TestCompletionApi:
    """Tests for the completion-messages generation endpoint."""

    def test_wrong_mode_raises(self, app: Flask) -> None:
        """Chat apps may not call the completion endpoint."""
        with app.test_request_context("/completion-messages", method="POST"), pytest.raises(NotCompletionAppError):
            CompletionApi().post(_chat_app(), _end_user())

    @patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "hi"})
    @patch("controllers.web.completion.AppGenerateService.generate")
    @patch("controllers.web.completion.web_ns")
    def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
        """A valid payload is generated and compacted into the response."""
        mock_ns.payload = {"inputs": {}, "query": "test"}
        mock_gen.return_value = "response-obj"
        with app.test_request_context("/completion-messages", method="POST"):
            response = CompletionApi().post(_completion_app(), _end_user())
        assert response == {"answer": "hi"}

    @patch(
        "controllers.web.completion.AppGenerateService.generate",
        side_effect=ProviderTokenNotInitError(description="not init"),
    )
    @patch("controllers.web.completion.web_ns")
    def test_provider_not_init_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
        """An uninitialised provider maps to ProviderNotInitializeError."""
        mock_ns.payload = {"inputs": {}}
        with (
            app.test_request_context("/completion-messages", method="POST"),
            pytest.raises(ProviderNotInitializeError),
        ):
            CompletionApi().post(_completion_app(), _end_user())

    @patch(
        "controllers.web.completion.AppGenerateService.generate",
        side_effect=QuotaExceededError(),
    )
    @patch("controllers.web.completion.web_ns")
    def test_quota_exceeded_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
        """Quota exhaustion maps to ProviderQuotaExceededError."""
        mock_ns.payload = {"inputs": {}}
        with (
            app.test_request_context("/completion-messages", method="POST"),
            pytest.raises(ProviderQuotaExceededError),
        ):
            CompletionApi().post(_completion_app(), _end_user())

    @patch(
        "controllers.web.completion.AppGenerateService.generate",
        side_effect=ModelCurrentlyNotSupportError(),
    )
    @patch("controllers.web.completion.web_ns")
    def test_model_not_support_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
        """An unsupported model maps to ProviderModelCurrentlyNotSupportError."""
        mock_ns.payload = {"inputs": {}}
        with (
            app.test_request_context("/completion-messages", method="POST"),
            pytest.raises(ProviderModelCurrentlyNotSupportError),
        ):
            CompletionApi().post(_completion_app(), _end_user())
# ---------------------------------------------------------------------------
# CompletionStopApi
# ---------------------------------------------------------------------------
class TestCompletionStopApi:
    """Tests for stopping an in-flight completion task."""

    def test_wrong_mode_raises(self, app: Flask) -> None:
        """Chat apps may not stop completion tasks."""
        with (
            app.test_request_context("/completion-messages/task-1/stop", method="POST"),
            pytest.raises(NotCompletionAppError),
        ):
            CompletionStopApi().post(_chat_app(), _end_user(), "task-1")

    @patch("controllers.web.completion.AppTaskService.stop_task")
    def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
        """Stopping a task for a completion app yields a 200 success body."""
        with app.test_request_context("/completion-messages/task-1/stop", method="POST"):
            body, status_code = CompletionStopApi().post(_completion_app(), _end_user(), "task-1")
        assert status_code == 200
        assert body == {"result": "success"}
# ---------------------------------------------------------------------------
# ChatApi
# ---------------------------------------------------------------------------
class TestChatApi:
    """Tests for the chat-messages generation endpoint."""

    def test_wrong_mode_raises(self, app: Flask) -> None:
        """Completion apps may not call the chat endpoint."""
        with app.test_request_context("/chat-messages", method="POST"), pytest.raises(NotChatAppError):
            ChatApi().post(_completion_app(), _end_user())

    @patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "reply"})
    @patch("controllers.web.completion.AppGenerateService.generate")
    @patch("controllers.web.completion.web_ns")
    def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
        """A valid chat payload is generated and compacted into the response."""
        mock_ns.payload = {"inputs": {}, "query": "hi"}
        mock_gen.return_value = "response"
        with app.test_request_context("/chat-messages", method="POST"):
            response = ChatApi().post(_chat_app(), _end_user())
        assert response == {"answer": "reply"}

    @patch(
        "controllers.web.completion.AppGenerateService.generate",
        side_effect=InvokeError(description="rate limit"),
    )
    @patch("controllers.web.completion.web_ns")
    def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
        """Model-runtime invoke failures surface as CompletionRequestError."""
        mock_ns.payload = {"inputs": {}, "query": "x"}
        with app.test_request_context("/chat-messages", method="POST"), pytest.raises(CompletionRequestError):
            ChatApi().post(_chat_app(), _end_user())
# ---------------------------------------------------------------------------
# ChatStopApi
# ---------------------------------------------------------------------------
class TestChatStopApi:
    """Tests for stopping an in-flight chat task."""

    def test_wrong_mode_raises(self, app: Flask) -> None:
        """Completion apps may not stop chat tasks."""
        with (
            app.test_request_context("/chat-messages/task-1/stop", method="POST"),
            pytest.raises(NotChatAppError),
        ):
            ChatStopApi().post(_completion_app(), _end_user(), "task-1")

    @patch("controllers.web.completion.AppTaskService.stop_task")
    def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
        """Stopping a task for a chat app yields a 200 success body."""
        with app.test_request_context("/chat-messages/task-1/stop", method="POST"):
            body, status_code = ChatStopApi().post(_chat_app(), _end_user(), "task-1")
        assert status_code == 200
        assert body == {"result": "success"}

View File

@@ -0,0 +1,183 @@
"""Unit tests for controllers.web.conversation endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.web.conversation import (
ConversationApi,
ConversationListApi,
ConversationPinApi,
ConversationRenameApi,
ConversationUnPinApi,
)
from controllers.web.error import NotChatAppError
from services.errors.conversation import ConversationNotExistsError
def _chat_app() -> SimpleNamespace:
    """Minimal stand-in for an App model in chat mode."""
    return SimpleNamespace(**{"id": "app-1", "mode": "chat"})


def _completion_app() -> SimpleNamespace:
    """Minimal stand-in for an App model in completion mode."""
    return SimpleNamespace(**{"id": "app-1", "mode": "completion"})


def _end_user() -> SimpleNamespace:
    """Minimal stand-in for an EndUser record."""
    return SimpleNamespace(**{"id": "eu-1"})
# ---------------------------------------------------------------------------
# ConversationListApi
# ---------------------------------------------------------------------------
class TestConversationListApi:
    """Tests for listing conversations with cursor pagination."""
    def test_non_chat_mode_raises(self, app: Flask) -> None:
        """Completion apps have no conversation list."""
        with app.test_request_context("/conversations"):
            with pytest.raises(NotChatAppError):
                ConversationListApi().get(_completion_app(), _end_user())
    @patch("controllers.web.conversation.WebConversationService.pagination_by_last_id")
    @patch("controllers.web.conversation.db")
    def test_happy_path(self, mock_db: MagicMock, mock_paginate: MagicMock, app: Flask) -> None:
        """A chat app returns the paginated conversation page fields."""
        conv_id = str(uuid4())
        # Conversation stub carrying every field the marshalled response reads.
        conv = SimpleNamespace(
            id=conv_id,
            name="Test",
            inputs={},
            status="normal",
            introduction="",
            created_at=1700000000,
            updated_at=1700000000,
        )
        mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[conv])
        mock_db.engine = "engine"
        # The endpoint opens `Session(db.engine)` as a context manager, so the
        # stub must implement __enter__/__exit__ explicitly.
        session_mock = MagicMock()
        session_ctx = MagicMock()
        session_ctx.__enter__ = MagicMock(return_value=session_mock)
        session_ctx.__exit__ = MagicMock(return_value=False)
        with (
            app.test_request_context("/conversations?limit=20"),
            patch("controllers.web.conversation.Session", return_value=session_ctx),
        ):
            result = ConversationListApi().get(_chat_app(), _end_user())
        assert result["limit"] == 20
        assert result["has_more"] is False
# ---------------------------------------------------------------------------
# ConversationApi (delete)
# ---------------------------------------------------------------------------
class TestConversationApi:
    """Tests for deleting a conversation."""

    def test_non_chat_mode_raises(self, app: Flask) -> None:
        """Completion apps have no conversations to delete."""
        with app.test_request_context(f"/conversations/{uuid4()}"), pytest.raises(NotChatAppError):
            ConversationApi().delete(_completion_app(), _end_user(), uuid4())

    @patch("controllers.web.conversation.ConversationService.delete")
    def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
        """Deleting an existing conversation yields 204 with a success body."""
        conversation_id = uuid4()
        with app.test_request_context(f"/conversations/{conversation_id}"):
            body, status_code = ConversationApi().delete(_chat_app(), _end_user(), conversation_id)
        assert status_code == 204
        assert body["result"] == "success"

    @patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError())
    def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None:
        """Deleting a missing conversation surfaces as HTTP 404."""
        conversation_id = uuid4()
        with (
            app.test_request_context(f"/conversations/{conversation_id}"),
            pytest.raises(NotFound, match="Conversation Not Exists"),
        ):
            ConversationApi().delete(_chat_app(), _end_user(), conversation_id)
# ---------------------------------------------------------------------------
# ConversationRenameApi
# ---------------------------------------------------------------------------
class TestConversationRenameApi:
    """Tests for renaming a conversation."""

    def test_non_chat_mode_raises(self, app: Flask) -> None:
        """Completion apps cannot rename conversations."""
        with (
            app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}),
            pytest.raises(NotChatAppError),
        ):
            ConversationRenameApi().post(_completion_app(), _end_user(), uuid4())

    @patch("controllers.web.conversation.ConversationService.rename")
    @patch("controllers.web.conversation.web_ns")
    def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
        """A successful rename echoes the updated conversation fields."""
        conversation_id = uuid4()
        mock_ns.payload = {"name": "New Name", "auto_generate": False}
        mock_rename.return_value = SimpleNamespace(
            id=str(conversation_id),
            name="New Name",
            inputs={},
            status="normal",
            introduction="",
            created_at=1700000000,
            updated_at=1700000000,
        )
        url = f"/conversations/{conversation_id}/name"
        with app.test_request_context(url, method="POST", json={"name": "New Name"}):
            response = ConversationRenameApi().post(_chat_app(), _end_user(), conversation_id)
        assert response["name"] == "New Name"

    @patch(
        "controllers.web.conversation.ConversationService.rename",
        side_effect=ConversationNotExistsError(),
    )
    @patch("controllers.web.conversation.web_ns")
    def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
        """Renaming a missing conversation surfaces as HTTP 404."""
        conversation_id = uuid4()
        mock_ns.payload = {"name": "X", "auto_generate": False}
        url = f"/conversations/{conversation_id}/name"
        with (
            app.test_request_context(url, method="POST", json={"name": "X"}),
            pytest.raises(NotFound, match="Conversation Not Exists"),
        ):
            ConversationRenameApi().post(_chat_app(), _end_user(), conversation_id)
# ---------------------------------------------------------------------------
# ConversationPinApi / ConversationUnPinApi
# ---------------------------------------------------------------------------
class TestConversationPinApi:
    """Tests for pinning a conversation."""

    def test_non_chat_mode_raises(self, app: Flask) -> None:
        """Completion apps cannot pin conversations."""
        with (
            app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"),
            pytest.raises(NotChatAppError),
        ):
            ConversationPinApi().patch(_completion_app(), _end_user(), uuid4())

    @patch("controllers.web.conversation.WebConversationService.pin")
    def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None:
        """Pinning an existing conversation is acknowledged."""
        conversation_id = uuid4()
        with app.test_request_context(f"/conversations/{conversation_id}/pin", method="PATCH"):
            response = ConversationPinApi().patch(_chat_app(), _end_user(), conversation_id)
        assert response["result"] == "success"

    @patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError())
    def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None:
        """Pinning a missing conversation surfaces as HTTP 404."""
        conversation_id = uuid4()
        with (
            app.test_request_context(f"/conversations/{conversation_id}/pin", method="PATCH"),
            pytest.raises(NotFound),
        ):
            ConversationPinApi().patch(_chat_app(), _end_user(), conversation_id)
class TestConversationUnPinApi:
    """Tests for unpinning a conversation."""

    def test_non_chat_mode_raises(self, app: Flask) -> None:
        """Completion apps cannot unpin conversations."""
        with (
            app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"),
            pytest.raises(NotChatAppError),
        ):
            ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4())

    @patch("controllers.web.conversation.WebConversationService.unpin")
    def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None:
        """Unpinning an existing conversation is acknowledged."""
        conversation_id = uuid4()
        with app.test_request_context(f"/conversations/{conversation_id}/unpin", method="PATCH"):
            response = ConversationUnPinApi().patch(_chat_app(), _end_user(), conversation_id)
        assert response["result"] == "success"

View File

@@ -0,0 +1,75 @@
"""Unit tests for controllers.web.error HTTP exception classes."""
from __future__ import annotations
import pytest
from controllers.web.error import (
AppMoreLikeThisDisabledError,
AppSuggestedQuestionsAfterAnswerDisabledError,
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
ConversationCompletedError,
InvalidArgumentError,
InvokeRateLimitError,
NoAudioUploadedError,
NotChatAppError,
NotCompletionAppError,
NotFoundError,
NotWorkflowAppError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
WebAppAuthAccessDeniedError,
WebAppAuthRequiredError,
WebFormRateLimitExceededError,
)
# (error class, expected error_code string, expected HTTP status) triples
# driving the parametrized attribute test below.
_ERROR_SPECS: list[tuple[type, str, int]] = [
    (AppUnavailableError, "app_unavailable", 400),
    (NotCompletionAppError, "not_completion_app", 400),
    (NotChatAppError, "not_chat_app", 400),
    (NotWorkflowAppError, "not_workflow_app", 400),
    (ConversationCompletedError, "conversation_completed", 400),
    (ProviderNotInitializeError, "provider_not_initialize", 400),
    (ProviderQuotaExceededError, "provider_quota_exceeded", 400),
    (ProviderModelCurrentlyNotSupportError, "model_currently_not_support", 400),
    (CompletionRequestError, "completion_request_error", 400),
    (AppMoreLikeThisDisabledError, "app_more_like_this_disabled", 403),
    (AppSuggestedQuestionsAfterAnswerDisabledError, "app_suggested_questions_after_answer_disabled", 403),
    (NoAudioUploadedError, "no_audio_uploaded", 400),
    (AudioTooLargeError, "audio_too_large", 413),
    (UnsupportedAudioTypeError, "unsupported_audio_type", 415),
    (ProviderNotSupportSpeechToTextError, "provider_not_support_speech_to_text", 400),
    (WebAppAuthRequiredError, "web_sso_auth_required", 401),
    (WebAppAuthAccessDeniedError, "web_app_access_denied", 401),
    (InvokeRateLimitError, "rate_limit_error", 429),
    (WebFormRateLimitExceededError, "web_form_rate_limit_exceeded", 429),
    (NotFoundError, "not_found", 404),
    (InvalidArgumentError, "invalid_param", 400),
]
@pytest.mark.parametrize(
    ("cls", "expected_code", "expected_status"),
    _ERROR_SPECS,
    ids=[cls.__name__ for cls, _, _ in _ERROR_SPECS],
)
def test_error_class_attributes(cls: type, expected_code: str, expected_status: int) -> None:
    """Each error class exposes the correct error_code and HTTP status code."""
    assert (cls.error_code, cls.code) == (expected_code, expected_status)
def test_error_classes_have_description() -> None:
    """Every error class has a description (string or None for generic errors)."""
    # NotFoundError and InvalidArgumentError use None description by design
    skip = {NotFoundError, InvalidArgumentError}
    described = [cls for cls, _, _ in _ERROR_SPECS if cls not in skip]
    for cls in described:
        assert isinstance(cls.description, str), f"{cls.__name__} missing description"
        assert len(cls.description) > 0, f"{cls.__name__} has empty description"

View File

@@ -0,0 +1,38 @@
"""Unit tests for controllers.web.feature endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from flask import Flask
from controllers.web.feature import SystemFeatureApi
class TestSystemFeatureApi:
    """Tests for the unauthenticated /system-features endpoint."""

    @patch("controllers.web.feature.FeatureService.get_system_features")
    def test_returns_system_features(self, mock_features: MagicMock, app: Flask) -> None:
        """The endpoint dumps the FeatureService model straight to the client."""
        features = {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}}
        feature_model = MagicMock()
        feature_model.model_dump.return_value = features
        mock_features.return_value = feature_model
        with app.test_request_context("/system-features"):
            response = SystemFeatureApi().get()
        assert response == features
        mock_features.assert_called_once()

    @patch("controllers.web.feature.FeatureService.get_system_features")
    def test_unauthenticated_access(self, mock_features: MagicMock, app: Flask) -> None:
        """SystemFeatureApi is unauthenticated by design — no WebApiResource decorator."""
        feature_model = MagicMock()
        feature_model.model_dump.return_value = {}
        mock_features.return_value = feature_model
        # Verify it's a bare Resource, not WebApiResource
        from flask_restx import Resource

        from controllers.web.wraps import WebApiResource

        assert issubclass(SystemFeatureApi, Resource)
        assert not issubclass(SystemFeatureApi, WebApiResource)

View File

@@ -0,0 +1,89 @@
"""Unit tests for controllers.web.files endpoints."""
from __future__ import annotations
from io import BytesIO
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
)
from controllers.web.files import FileApi
def _app_model() -> SimpleNamespace:
    """Minimal stand-in for an App model."""
    return SimpleNamespace(**{"id": "app-1"})


def _end_user() -> SimpleNamespace:
    """Minimal stand-in for an EndUser record."""
    return SimpleNamespace(**{"id": "eu-1"})
class TestFileApi:
    """Tests for the web file-upload endpoint: request validation plus service success and failure paths."""
    def test_no_file_uploaded(self, app: Flask) -> None:
        """A multipart request without a 'file' part is rejected."""
        with app.test_request_context("/files/upload", method="POST", content_type="multipart/form-data"):
            with pytest.raises(NoFileUploadedError):
                FileApi().post(_app_model(), _end_user())
    def test_too_many_files(self, app: Flask) -> None:
        """More than one uploaded file part is rejected."""
        data = {
            "file": (BytesIO(b"a"), "a.txt"),
            "file2": (BytesIO(b"b"), "b.txt"),
        }
        with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
            # Now has "file" key but len(request.files) > 1
            with pytest.raises(TooManyFilesError):
                FileApi().post(_app_model(), _end_user())
    def test_filename_missing(self, app: Flask) -> None:
        """An uploaded file with an empty filename is rejected."""
        data = {"file": (BytesIO(b"content"), "")}
        with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
            with pytest.raises(FilenameNotExistsError):
                FileApi().post(_app_model(), _end_user())
    @patch("controllers.web.files.FileService")
    @patch("controllers.web.files.db")
    def test_upload_success(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None:
        """A valid upload returns 201 with the stored file's metadata."""
        mock_db.engine = "engine"
        from datetime import datetime
        # Upload record stub carrying every field the marshalled response reads.
        upload_file = SimpleNamespace(
            id="file-1",
            name="test.txt",
            size=100,
            extension="txt",
            mime_type="text/plain",
            created_by="eu-1",
            created_at=datetime(2024, 1, 1),
        )
        mock_file_svc_cls.return_value.upload_file.return_value = upload_file
        data = {"file": (BytesIO(b"content"), "test.txt")}
        with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
            result, status = FileApi().post(_app_model(), _end_user())
            assert status == 201
            assert result["id"] == "file-1"
            assert result["name"] == "test.txt"
    @patch("controllers.web.files.FileService")
    @patch("controllers.web.files.db")
    def test_file_too_large_from_service(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None:
        """The service-layer size error maps to the HTTP FileTooLargeError."""
        import services.errors.file
        mock_db.engine = "engine"
        mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError(
            description="max 10MB"
        )
        data = {"file": (BytesIO(b"big"), "big.txt")}
        with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
            with pytest.raises(FileTooLargeError):
                FileApi().post(_app_model(), _end_user())

View File

@@ -0,0 +1,156 @@
"""Unit tests for controllers.web.message — feedback, more-like-this, suggested questions."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.web.error import (
AppMoreLikeThisDisabledError,
NotChatAppError,
NotCompletionAppError,
)
from controllers.web.message import (
MessageFeedbackApi,
MessageMoreLikeThisApi,
MessageSuggestedQuestionApi,
)
from services.errors.app import MoreLikeThisDisabledError
from services.errors.message import MessageNotExistsError
def _chat_app() -> SimpleNamespace:
    """Minimal stand-in for an App model in chat mode."""
    return SimpleNamespace(**{"id": "app-1", "mode": "chat"})


def _completion_app() -> SimpleNamespace:
    """Minimal stand-in for an App model in completion mode."""
    return SimpleNamespace(**{"id": "app-1", "mode": "completion"})


def _end_user() -> SimpleNamespace:
    """Minimal stand-in for an EndUser record."""
    return SimpleNamespace(**{"id": "eu-1"})
# ---------------------------------------------------------------------------
# MessageFeedbackApi
# ---------------------------------------------------------------------------
class TestMessageFeedbackApi:
    """Tests for POST /messages/<id>/feedbacks."""

    @patch("controllers.web.message.MessageService.create_feedback")
    @patch("controllers.web.message.web_ns")
    def test_feedback_success(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
        """A like rating with content is stored and acknowledged."""
        mock_ns.payload = {"rating": "like", "content": "great"}
        message_id = uuid4()
        with app.test_request_context(f"/messages/{message_id}/feedbacks", method="POST"):
            response = MessageFeedbackApi().post(_chat_app(), _end_user(), message_id)
            assert response == {"result": "success"}
            mock_create.assert_called_once()

    @patch("controllers.web.message.MessageService.create_feedback")
    @patch("controllers.web.message.web_ns")
    def test_feedback_null_rating(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
        """A null rating (feedback reset) is accepted."""
        mock_ns.payload = {"rating": None}
        message_id = uuid4()
        with app.test_request_context(f"/messages/{message_id}/feedbacks", method="POST"):
            response = MessageFeedbackApi().post(_chat_app(), _end_user(), message_id)
            assert response == {"result": "success"}

    @patch(
        "controllers.web.message.MessageService.create_feedback",
        side_effect=MessageNotExistsError(),
    )
    @patch("controllers.web.message.web_ns")
    def test_feedback_message_not_found(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
        """Feedback on an unknown message surfaces as 404."""
        mock_ns.payload = {"rating": "dislike"}
        message_id = uuid4()
        with app.test_request_context(f"/messages/{message_id}/feedbacks", method="POST"):
            with pytest.raises(NotFound, match="Message Not Exists"):
                MessageFeedbackApi().post(_chat_app(), _end_user(), message_id)
# ---------------------------------------------------------------------------
# MessageMoreLikeThisApi
# ---------------------------------------------------------------------------
class TestMessageMoreLikeThisApi:
    """Tests for GET /messages/<id>/more-like-this."""

    def test_wrong_mode_raises(self, app: Flask) -> None:
        """More-like-this is only available for completion-mode apps."""
        message_id = uuid4()
        with app.test_request_context(f"/messages/{message_id}/more-like-this?response_mode=blocking"):
            with pytest.raises(NotCompletionAppError):
                MessageMoreLikeThisApi().get(_chat_app(), _end_user(), message_id)

    @patch("controllers.web.message.helper.compact_generate_response", return_value={"answer": "similar"})
    @patch("controllers.web.message.AppGenerateService.generate_more_like_this")
    def test_happy_path(self, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
        """The generated response is compacted and returned."""
        message_id = uuid4()
        mock_gen.return_value = "response"
        with app.test_request_context(f"/messages/{message_id}/more-like-this?response_mode=blocking"):
            response = MessageMoreLikeThisApi().get(_completion_app(), _end_user(), message_id)
            assert response == {"answer": "similar"}

    @patch(
        "controllers.web.message.AppGenerateService.generate_more_like_this",
        side_effect=MessageNotExistsError(),
    )
    def test_message_not_found(self, mock_gen: MagicMock, app: Flask) -> None:
        """An unknown message id surfaces as 404."""
        message_id = uuid4()
        with app.test_request_context(f"/messages/{message_id}/more-like-this?response_mode=blocking"):
            with pytest.raises(NotFound, match="Message Not Exists"):
                MessageMoreLikeThisApi().get(_completion_app(), _end_user(), message_id)

    @patch(
        "controllers.web.message.AppGenerateService.generate_more_like_this",
        side_effect=MoreLikeThisDisabledError(),
    )
    def test_feature_disabled(self, mock_gen: MagicMock, app: Flask) -> None:
        """A disabled more-like-this feature maps to the web-layer error."""
        message_id = uuid4()
        with app.test_request_context(f"/messages/{message_id}/more-like-this?response_mode=blocking"):
            with pytest.raises(AppMoreLikeThisDisabledError):
                MessageMoreLikeThisApi().get(_completion_app(), _end_user(), message_id)
# ---------------------------------------------------------------------------
# MessageSuggestedQuestionApi
# ---------------------------------------------------------------------------
class TestMessageSuggestedQuestionApi:
    """Tests for GET /messages/<id>/suggested-questions.

    NOTE: the original file defined ``test_wrong_mode_raises`` twice with
    identical bodies; the second definition shadowed the first, so only one
    copy ever ran. The duplicate has been removed.
    """

    def test_wrong_mode_raises(self, app: Flask) -> None:
        """Suggested questions are only available for chat-mode apps."""
        msg_id = uuid4()
        with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
            with pytest.raises(NotChatAppError):
                MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)

    @patch("controllers.web.message.MessageService.get_suggested_questions_after_answer")
    def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None:
        """Suggested questions from the service are returned under "data"."""
        msg_id = uuid4()
        mock_suggest.return_value = ["What about X?", "Tell me more about Y."]
        with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
            result = MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id)
            assert result["data"] == ["What about X?", "Tell me more about Y."]

    @patch(
        "controllers.web.message.MessageService.get_suggested_questions_after_answer",
        side_effect=MessageNotExistsError(),
    )
    def test_message_not_found(self, mock_suggest: MagicMock, app: Flask) -> None:
        """An unknown message id surfaces as 404."""
        msg_id = uuid4()
        with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
            with pytest.raises(NotFound, match="Message not found"):
                MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id)

View File

@@ -0,0 +1,103 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.web.error import WebAppAuthRequiredError
from controllers.web.passport import (
PassportService,
decode_enterprise_webapp_user_id,
exchange_token_for_existing_web_user,
generate_session_id,
)
from services.webapp_auth_service import WebAppAuthType
def test_decode_enterprise_webapp_user_id_none() -> None:
    """A missing token decodes to None without raising."""
    token = None
    assert decode_enterprise_webapp_user_id(token) is None
def test_decode_enterprise_webapp_user_id_invalid_source(monkeypatch: pytest.MonkeyPatch) -> None:
    """A token from an unexpected source is rejected as Unauthorized."""
    bad_payload = {"token_source": "bad"}
    monkeypatch.setattr(PassportService, "verify", lambda *_a, **_k: bad_payload)
    with pytest.raises(Unauthorized):
        decode_enterprise_webapp_user_id("token")
def test_decode_enterprise_webapp_user_id_valid(monkeypatch: pytest.MonkeyPatch) -> None:
    """A webapp_login_token payload is returned unchanged."""
    payload = {"token_source": "webapp_login_token", "user_id": "u1"}
    monkeypatch.setattr(PassportService, "verify", lambda *_a, **_k: payload)
    assert decode_enterprise_webapp_user_id("token") == payload
def test_exchange_token_public_flow(monkeypatch: pytest.MonkeyPatch) -> None:
    """A public-auth app delegates to _exchange_for_public_app_token."""
    site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
    app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)
    # db.session.scalar returns the site on the first lookup, the app after that.
    calls = {"n": 0}

    def _scalar(*_a, **_k):
        calls["n"] += 1
        return site if calls["n"] == 1 else app_model

    db_session = SimpleNamespace(scalar=_scalar)
    monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
    monkeypatch.setattr("controllers.web.passport._exchange_for_public_app_token", lambda *_a, **_k: "resp")
    decoded = {"auth_type": "public"}
    assert exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.PUBLIC) == "resp"
def test_exchange_token_requires_external(monkeypatch: pytest.MonkeyPatch) -> None:
    """An internal token cannot be exchanged for an app requiring external auth."""
    site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
    app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)
    calls = {"n": 0}

    def _scalar(*_a, **_k):
        calls["n"] += 1
        return site if calls["n"] == 1 else app_model

    db_session = SimpleNamespace(scalar=_scalar)
    monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
    decoded = {"auth_type": "internal"}
    with pytest.raises(WebAppAuthRequiredError):
        exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.EXTERNAL)
def test_exchange_token_missing_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
    """If no end-user session can be resolved after site/app lookups, NotFound is raised."""
    site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
    app_model = SimpleNamespace(id="a1", status="normal", enable_site=True, tenant_id="t1")
    # Lookup sequence: site, then app model, then None for everything else.
    calls = {"n": 0}

    def _scalar(*_a, **_k):
        calls["n"] += 1
        if calls["n"] == 1:
            return site
        return app_model if calls["n"] == 2 else None

    db_session = SimpleNamespace(scalar=_scalar, add=lambda *_a, **_k: None, commit=lambda: None)
    monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
    decoded = {"auth_type": "internal"}
    with pytest.raises(NotFound):
        exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.INTERNAL)
def test_generate_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
    """generate_session_id retries until the collision count query returns zero."""
    remaining = [1, 0]  # first candidate collides, second is free
    db_session = SimpleNamespace(scalar=lambda *_a, **_k: remaining.pop(0))
    monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
    assert generate_session_id()

View File

@@ -0,0 +1,423 @@
"""Unit tests for Pydantic models defined in controllers.web modules.
Covers validation logic, field defaults, constraints, and custom validators
for all ~15 Pydantic models across the web controller layer.
"""
from __future__ import annotations
from uuid import uuid4
import pytest
from pydantic import ValidationError
# ---------------------------------------------------------------------------
# app.py models
# ---------------------------------------------------------------------------
from controllers.web.app import AppAccessModeQuery
class TestAppAccessModeQuery:
    """Validation of the appId/appCode access-mode query model."""

    def test_alias_resolution(self) -> None:
        """camelCase aliases map onto the snake_case fields."""
        query = AppAccessModeQuery.model_validate({"appId": "abc", "appCode": "xyz"})
        assert query.app_id == "abc"
        assert query.app_code == "xyz"

    def test_defaults_to_none(self) -> None:
        """Both fields are optional and default to None."""
        query = AppAccessModeQuery.model_validate({})
        assert query.app_id is None
        assert query.app_code is None

    def test_accepts_snake_case(self) -> None:
        """The model also accepts the snake_case field names directly."""
        query = AppAccessModeQuery(app_id="id1", app_code="code1")
        assert query.app_id == "id1"
        assert query.app_code == "code1"
# ---------------------------------------------------------------------------
# audio.py models
# ---------------------------------------------------------------------------
from controllers.web.audio import TextToAudioPayload
class TestTextToAudioPayload:
    """Validation of the text-to-audio request payload."""

    def test_defaults(self) -> None:
        payload = TextToAudioPayload.model_validate({})
        assert payload.message_id is None
        assert payload.voice is None
        assert payload.text is None
        assert payload.streaming is None

    def test_valid_uuid_message_id(self) -> None:
        value = str(uuid4())
        assert TextToAudioPayload(message_id=value).message_id == value

    def test_none_message_id_passthrough(self) -> None:
        assert TextToAudioPayload(message_id=None).message_id is None

    def test_invalid_uuid_message_id(self) -> None:
        with pytest.raises(ValidationError, match="not a valid uuid"):
            TextToAudioPayload(message_id="not-a-uuid")
# ---------------------------------------------------------------------------
# completion.py models
# ---------------------------------------------------------------------------
from controllers.web.completion import ChatMessagePayload, CompletionMessagePayload
class TestCompletionMessagePayload:
    """Validation of the completion-message request body."""

    def test_defaults(self) -> None:
        payload = CompletionMessagePayload(inputs={})
        assert payload.query == ""
        assert payload.files is None
        assert payload.response_mode is None
        assert payload.retriever_from == "web_app"

    def test_accepts_full_payload(self) -> None:
        payload = CompletionMessagePayload(
            inputs={"key": "val"},
            query="test",
            files=[{"id": "f1"}],
            response_mode="streaming",
        )
        assert payload.response_mode == "streaming"
        assert payload.files == [{"id": "f1"}]

    def test_invalid_response_mode(self) -> None:
        with pytest.raises(ValidationError):
            CompletionMessagePayload(inputs={}, response_mode="invalid")
class TestChatMessagePayload:
    """Validation of the chat-message request body."""

    def test_valid_uuid_fields(self) -> None:
        conversation_id = str(uuid4())
        parent_id = str(uuid4())
        payload = ChatMessagePayload(
            inputs={}, query="hi", conversation_id=conversation_id, parent_message_id=parent_id
        )
        assert payload.conversation_id == conversation_id
        assert payload.parent_message_id == parent_id

    def test_none_uuid_fields(self) -> None:
        payload = ChatMessagePayload(inputs={}, query="hi")
        assert payload.conversation_id is None
        assert payload.parent_message_id is None

    def test_invalid_conversation_id(self) -> None:
        with pytest.raises(ValidationError, match="not a valid uuid"):
            ChatMessagePayload(inputs={}, query="hi", conversation_id="bad")

    def test_invalid_parent_message_id(self) -> None:
        with pytest.raises(ValidationError, match="not a valid uuid"):
            ChatMessagePayload(inputs={}, query="hi", parent_message_id="bad")

    def test_query_required(self) -> None:
        with pytest.raises(ValidationError):
            ChatMessagePayload(inputs={})
# ---------------------------------------------------------------------------
# conversation.py models
# ---------------------------------------------------------------------------
from controllers.web.conversation import ConversationListQuery, ConversationRenamePayload
class TestConversationListQuery:
    """Validation of the conversation-list query parameters."""

    def test_defaults(self) -> None:
        query = ConversationListQuery()
        assert query.last_id is None
        assert query.limit == 20
        assert query.pinned is None
        assert query.sort_by == "-updated_at"

    def test_limit_lower_bound(self) -> None:
        with pytest.raises(ValidationError):
            ConversationListQuery(limit=0)

    def test_limit_upper_bound(self) -> None:
        with pytest.raises(ValidationError):
            ConversationListQuery(limit=101)

    def test_limit_boundaries_valid(self) -> None:
        for boundary in (1, 100):
            assert ConversationListQuery(limit=boundary).limit == boundary

    def test_valid_sort_by_options(self) -> None:
        for option in ("created_at", "-created_at", "updated_at", "-updated_at"):
            assert ConversationListQuery(sort_by=option).sort_by == option

    def test_invalid_sort_by(self) -> None:
        with pytest.raises(ValidationError):
            ConversationListQuery(sort_by="invalid")

    def test_valid_last_id(self) -> None:
        value = str(uuid4())
        assert ConversationListQuery(last_id=value).last_id == value

    def test_invalid_last_id(self) -> None:
        with pytest.raises(ValidationError, match="not a valid uuid"):
            ConversationListQuery(last_id="not-uuid")
class TestConversationRenamePayload:
    """Validation of the rename body, including the name/auto_generate rule."""

    def test_auto_generate_true_no_name_required(self) -> None:
        assert ConversationRenamePayload(auto_generate=True).name is None

    def test_auto_generate_false_requires_name(self) -> None:
        with pytest.raises(ValidationError, match="name is required"):
            ConversationRenamePayload(auto_generate=False)

    def test_auto_generate_false_blank_name_rejected(self) -> None:
        with pytest.raises(ValidationError, match="name is required"):
            ConversationRenamePayload(auto_generate=False, name=" ")

    def test_auto_generate_false_with_valid_name(self) -> None:
        payload = ConversationRenamePayload(auto_generate=False, name="My Chat")
        assert payload.name == "My Chat"

    def test_defaults(self) -> None:
        payload = ConversationRenamePayload(name="test")
        assert payload.auto_generate is False
        assert payload.name == "test"
# ---------------------------------------------------------------------------
# message.py models
# ---------------------------------------------------------------------------
from controllers.web.message import MessageFeedbackPayload, MessageListQuery, MessageMoreLikeThisQuery
class TestMessageListQuery:
    """Validation of the message-list query parameters."""

    def test_valid_query(self) -> None:
        conversation_id = str(uuid4())
        query = MessageListQuery(conversation_id=conversation_id)
        assert query.conversation_id == conversation_id
        assert query.first_id is None
        assert query.limit == 20

    def test_invalid_conversation_id(self) -> None:
        with pytest.raises(ValidationError, match="not a valid uuid"):
            MessageListQuery(conversation_id="bad")

    def test_limit_bounds(self) -> None:
        conversation_id = str(uuid4())
        for bad_limit in (0, 101):
            with pytest.raises(ValidationError):
                MessageListQuery(conversation_id=conversation_id, limit=bad_limit)

    def test_valid_first_id(self) -> None:
        conversation_id = str(uuid4())
        first_id = str(uuid4())
        query = MessageListQuery(conversation_id=conversation_id, first_id=first_id)
        assert query.first_id == first_id

    def test_invalid_first_id(self) -> None:
        with pytest.raises(ValidationError, match="not a valid uuid"):
            MessageListQuery(conversation_id=str(uuid4()), first_id="invalid")
class TestMessageFeedbackPayload:
    """Validation of the feedback request body."""

    def test_defaults(self) -> None:
        payload = MessageFeedbackPayload()
        assert payload.rating is None
        assert payload.content is None

    def test_valid_ratings(self) -> None:
        for rating in ("like", "dislike"):
            assert MessageFeedbackPayload(rating=rating).rating == rating

    def test_invalid_rating(self) -> None:
        with pytest.raises(ValidationError):
            MessageFeedbackPayload(rating="neutral")
class TestMessageMoreLikeThisQuery:
    """Validation of the more-like-this response_mode query."""

    def test_valid_modes(self) -> None:
        for mode in ("blocking", "streaming"):
            assert MessageMoreLikeThisQuery(response_mode=mode).response_mode == mode

    def test_invalid_mode(self) -> None:
        with pytest.raises(ValidationError):
            MessageMoreLikeThisQuery(response_mode="invalid")

    def test_required(self) -> None:
        with pytest.raises(ValidationError):
            MessageMoreLikeThisQuery()
# ---------------------------------------------------------------------------
# remote_files.py models
# ---------------------------------------------------------------------------
from controllers.web.remote_files import RemoteFileUploadPayload
class TestRemoteFileUploadPayload:
    """Validation of the remote-file upload body (url field)."""

    def test_valid_url(self) -> None:
        payload = RemoteFileUploadPayload(url="https://example.com/file.pdf")
        assert str(payload.url) == "https://example.com/file.pdf"

    def test_invalid_url(self) -> None:
        with pytest.raises(ValidationError):
            RemoteFileUploadPayload(url="not-a-url")

    def test_url_required(self) -> None:
        with pytest.raises(ValidationError):
            RemoteFileUploadPayload()
# ---------------------------------------------------------------------------
# saved_message.py models
# ---------------------------------------------------------------------------
from controllers.web.saved_message import SavedMessageCreatePayload, SavedMessageListQuery
class TestSavedMessageListQuery:
    """Validation of the saved-message list query."""

    def test_defaults(self) -> None:
        query = SavedMessageListQuery()
        assert query.last_id is None
        assert query.limit == 20

    def test_limit_bounds(self) -> None:
        for bad_limit in (0, 101):
            with pytest.raises(ValidationError):
                SavedMessageListQuery(limit=bad_limit)

    def test_valid_last_id(self) -> None:
        value = str(uuid4())
        assert SavedMessageListQuery(last_id=value).last_id == value

    def test_empty_last_id(self) -> None:
        # Empty string is accepted as-is rather than rejected as a bad UUID.
        assert SavedMessageListQuery(last_id="").last_id == ""
class TestSavedMessageCreatePayload:
    """Validation of the saved-message create body."""

    def test_valid_message_id(self) -> None:
        value = str(uuid4())
        assert SavedMessageCreatePayload(message_id=value).message_id == value

    def test_required(self) -> None:
        with pytest.raises(ValidationError):
            SavedMessageCreatePayload()
# ---------------------------------------------------------------------------
# workflow.py models
# ---------------------------------------------------------------------------
from controllers.web.workflow import WorkflowRunPayload
class TestWorkflowRunPayload:
    """Validation of the workflow-run body."""

    def test_defaults(self) -> None:
        payload = WorkflowRunPayload(inputs={})
        assert payload.inputs == {}
        assert payload.files is None

    def test_with_files(self) -> None:
        payload = WorkflowRunPayload(inputs={"k": "v"}, files=[{"id": "f1"}])
        assert payload.files == [{"id": "f1"}]

    def test_inputs_required(self) -> None:
        with pytest.raises(ValidationError):
            WorkflowRunPayload()
# ---------------------------------------------------------------------------
# forgot_password.py models
# ---------------------------------------------------------------------------
from controllers.web.forgot_password import (
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordSendPayload,
)
class TestForgotPasswordSendPayload:
    """Validation of the forgot-password send request."""

    def test_valid_email(self) -> None:
        assert ForgotPasswordSendPayload(email="user@example.com").email == "user@example.com"

    def test_invalid_email(self) -> None:
        with pytest.raises(ValidationError, match="not a valid email"):
            ForgotPasswordSendPayload(email="not-an-email")

    def test_language_optional(self) -> None:
        assert ForgotPasswordSendPayload(email="a@b.com").language is None
class TestForgotPasswordCheckPayload:
    """Validation of the forgot-password code-check body."""

    def test_valid(self) -> None:
        payload = ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="tok")
        assert (payload.email, payload.code, payload.token) == ("a@b.com", "1234", "tok")

    def test_empty_token_rejected(self) -> None:
        with pytest.raises(ValidationError):
            ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="")
class TestForgotPasswordResetPayload:
    """Validation of the password-reset body, including strength rules."""

    def test_valid_passwords(self) -> None:
        payload = ForgotPasswordResetPayload(token="tok", new_password="Valid1234", password_confirm="Valid1234")
        assert payload.new_password == "Valid1234"

    def test_weak_password_rejected(self) -> None:
        with pytest.raises(ValidationError, match="Password must contain"):
            ForgotPasswordResetPayload(token="tok", new_password="short", password_confirm="short")

    def test_letters_only_password_rejected(self) -> None:
        with pytest.raises(ValidationError, match="Password must contain"):
            ForgotPasswordResetPayload(token="tok", new_password="abcdefghi", password_confirm="abcdefghi")

    def test_digits_only_password_rejected(self) -> None:
        with pytest.raises(ValidationError, match="Password must contain"):
            ForgotPasswordResetPayload(token="tok", new_password="123456789", password_confirm="123456789")
# ---------------------------------------------------------------------------
# login.py models
# ---------------------------------------------------------------------------
from controllers.web.login import EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload, LoginPayload
class TestLoginPayload:
    """Validation of the login body."""

    def test_valid(self) -> None:
        assert LoginPayload(email="a@b.com", password="Valid1234").email == "a@b.com"

    def test_invalid_email(self) -> None:
        with pytest.raises(ValidationError, match="not a valid email"):
            LoginPayload(email="bad", password="Valid1234")

    def test_weak_password(self) -> None:
        with pytest.raises(ValidationError, match="Password must contain"):
            LoginPayload(email="a@b.com", password="weak")
class TestEmailCodeLoginSendPayload:
    """Validation of the email-code send body."""

    def test_valid(self) -> None:
        assert EmailCodeLoginSendPayload(email="a@b.com").language is None

    def test_with_language(self) -> None:
        assert EmailCodeLoginSendPayload(email="a@b.com", language="zh-Hans").language == "zh-Hans"
class TestEmailCodeLoginVerifyPayload:
    """Validation of the email-code verify body."""

    def test_valid(self) -> None:
        assert EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="tok").code == "1234"

    def test_empty_token_rejected(self) -> None:
        with pytest.raises(ValidationError):
            EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="")

View File

@@ -0,0 +1,147 @@
"""Unit tests for controllers.web.remote_files endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError
from controllers.web.remote_files import RemoteFileInfoApi, RemoteFileUploadApi
def _app_model() -> SimpleNamespace:
return SimpleNamespace(id="app-1")
def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1")
# ---------------------------------------------------------------------------
# RemoteFileInfoApi
# ---------------------------------------------------------------------------
class TestRemoteFileInfoApi:
    """Tests for the remote-file metadata (HEAD) endpoint."""

    @patch("controllers.web.remote_files.ssrf_proxy")
    def test_head_success(self, mock_proxy: MagicMock, app: Flask) -> None:
        """A successful HEAD yields content type and integer length."""
        head_resp = MagicMock()
        head_resp.status_code = 200
        head_resp.headers = {"Content-Type": "application/pdf", "Content-Length": "1024"}
        mock_proxy.head.return_value = head_resp
        encoded_url = "https%3A%2F%2Fexample.com%2Ffile.pdf"
        with app.test_request_context(f"/remote-files/{encoded_url}"):
            info = RemoteFileInfoApi().get(_app_model(), _end_user(), encoded_url)
            assert info["file_type"] == "application/pdf"
            assert info["file_length"] == 1024

    @patch("controllers.web.remote_files.ssrf_proxy")
    def test_fallback_to_get(self, mock_proxy: MagicMock, app: Flask) -> None:
        """When HEAD is not allowed (405), the endpoint falls back to GET."""
        head_resp = MagicMock()
        head_resp.status_code = 405  # Method not allowed
        get_resp = MagicMock()
        get_resp.status_code = 200
        get_resp.headers = {"Content-Type": "text/plain", "Content-Length": "42"}
        get_resp.raise_for_status = MagicMock()
        mock_proxy.head.return_value = head_resp
        mock_proxy.get.return_value = get_resp
        encoded_url = "https%3A%2F%2Fexample.com%2Ffile.txt"
        with app.test_request_context(f"/remote-files/{encoded_url}"):
            info = RemoteFileInfoApi().get(_app_model(), _end_user(), encoded_url)
            assert info["file_type"] == "text/plain"
            mock_proxy.get.assert_called_once()
# ---------------------------------------------------------------------------
# RemoteFileUploadApi
# ---------------------------------------------------------------------------
class TestRemoteFileUploadApi:
    """Tests for POST /remote-files/upload."""

    @patch("controllers.web.remote_files.file_helpers.get_signed_file_url", return_value="https://signed-url")
    @patch("controllers.web.remote_files.FileService")
    @patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
    @patch("controllers.web.remote_files.ssrf_proxy")
    @patch("controllers.web.remote_files.web_ns")
    @patch("controllers.web.remote_files.db")
    def test_upload_success(
        self,
        mock_db: MagicMock,
        mock_ns: MagicMock,
        mock_proxy: MagicMock,
        mock_guess: MagicMock,
        mock_file_svc_cls: MagicMock,
        mock_signed: MagicMock,
        app: Flask,
    ) -> None:
        """A fetchable remote file within the size limit is stored and returned with 201."""
        from datetime import datetime

        mock_db.engine = "engine"
        mock_ns.payload = {"url": "https://example.com/file.pdf"}
        head_resp = MagicMock()
        head_resp.status_code = 200
        head_resp.content = b"pdf-content"
        head_resp.request.method = "HEAD"
        mock_proxy.head.return_value = head_resp
        get_resp = MagicMock()
        get_resp.content = b"pdf-content"
        mock_proxy.get.return_value = get_resp
        mock_guess.return_value = SimpleNamespace(
            filename="file.pdf", extension="pdf", mimetype="application/pdf", size=100
        )
        mock_file_svc_cls.is_file_size_within_limit.return_value = True
        stored_file = SimpleNamespace(
            id="f-1",
            name="file.pdf",
            size=100,
            extension="pdf",
            mime_type="application/pdf",
            created_by="eu-1",
            created_at=datetime(2024, 1, 1),
        )
        mock_file_svc_cls.return_value.upload_file.return_value = stored_file
        with app.test_request_context("/remote-files/upload", method="POST"):
            body, status = RemoteFileUploadApi().post(_app_model(), _end_user())
            assert status == 201
            assert body["id"] == "f-1"

    @patch("controllers.web.remote_files.FileService.is_file_size_within_limit", return_value=False)
    @patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
    @patch("controllers.web.remote_files.ssrf_proxy")
    @patch("controllers.web.remote_files.web_ns")
    def test_file_too_large(
        self,
        mock_ns: MagicMock,
        mock_proxy: MagicMock,
        mock_guess: MagicMock,
        mock_size_check: MagicMock,
        app: Flask,
    ) -> None:
        """A remote file beyond the configured size limit is rejected."""
        mock_ns.payload = {"url": "https://example.com/big.zip"}
        head_resp = MagicMock()
        head_resp.status_code = 200
        mock_proxy.head.return_value = head_resp
        mock_guess.return_value = SimpleNamespace(
            filename="big.zip", extension="zip", mimetype="application/zip", size=999999999
        )
        with app.test_request_context("/remote-files/upload", method="POST"):
            with pytest.raises(FileTooLargeError):
                RemoteFileUploadApi().post(_app_model(), _end_user())

    @patch("controllers.web.remote_files.ssrf_proxy")
    @patch("controllers.web.remote_files.web_ns")
    def test_fetch_failure_raises(self, mock_ns: MagicMock, mock_proxy: MagicMock, app: Flask) -> None:
        """Network failures while fetching the URL map to RemoteFileUploadError."""
        import httpx

        mock_ns.payload = {"url": "https://example.com/bad"}
        mock_proxy.head.side_effect = httpx.RequestError("connection failed")
        with app.test_request_context("/remote-files/upload", method="POST"):
            with pytest.raises(RemoteFileUploadError):
                RemoteFileUploadApi().post(_app_model(), _end_user())

View File

@@ -0,0 +1,97 @@
"""Unit tests for controllers.web.saved_message endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.web.error import NotCompletionAppError
from controllers.web.saved_message import SavedMessageApi, SavedMessageListApi
from services.errors.message import MessageNotExistsError
def _completion_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="completion")
def _chat_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="chat")
def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1")
# ---------------------------------------------------------------------------
# SavedMessageListApi (GET)
# ---------------------------------------------------------------------------
class TestSavedMessageListApiGet:
    """Tests for GET /saved-messages."""

    def test_non_completion_mode_raises(self, app: Flask) -> None:
        """Saved messages exist only for completion-mode apps."""
        with app.test_request_context("/saved-messages"):
            with pytest.raises(NotCompletionAppError):
                SavedMessageListApi().get(_chat_app(), _end_user())

    @patch("controllers.web.saved_message.SavedMessageService.pagination_by_last_id")
    def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None:
        """Pagination results from the service are passed through."""
        mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[])
        with app.test_request_context("/saved-messages?limit=20"):
            page = SavedMessageListApi().get(_completion_app(), _end_user())
            assert page["limit"] == 20
            assert page["has_more"] is False
# ---------------------------------------------------------------------------
# SavedMessageListApi (POST)
# ---------------------------------------------------------------------------
class TestSavedMessageListApiPost:
    """Tests for POST /saved-messages."""

    def test_non_completion_mode_raises(self, app: Flask) -> None:
        """Chat apps cannot save messages."""
        with app.test_request_context("/saved-messages", method="POST"):
            with pytest.raises(NotCompletionAppError):
                SavedMessageListApi().post(_chat_app(), _end_user())

    @patch("controllers.web.saved_message.SavedMessageService.save")
    @patch("controllers.web.saved_message.web_ns")
    def test_save_success(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None:
        """Saving an existing message succeeds."""
        mock_ns.payload = {"message_id": str(uuid4())}
        with app.test_request_context("/saved-messages", method="POST"):
            response = SavedMessageListApi().post(_completion_app(), _end_user())
            assert response["result"] == "success"

    @patch("controllers.web.saved_message.SavedMessageService.save", side_effect=MessageNotExistsError())
    @patch("controllers.web.saved_message.web_ns")
    def test_save_not_found(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None:
        """Saving an unknown message surfaces as 404."""
        mock_ns.payload = {"message_id": str(uuid4())}
        with app.test_request_context("/saved-messages", method="POST"):
            with pytest.raises(NotFound, match="Message Not Exists"):
                SavedMessageListApi().post(_completion_app(), _end_user())
# ---------------------------------------------------------------------------
# SavedMessageApi (DELETE)
# ---------------------------------------------------------------------------
class TestSavedMessageApi:
    """Tests for DELETE /saved-messages/<id>."""

    def test_non_completion_mode_raises(self, app: Flask) -> None:
        """Chat apps cannot delete saved messages."""
        message_id = uuid4()
        with app.test_request_context(f"/saved-messages/{message_id}", method="DELETE"):
            with pytest.raises(NotCompletionAppError):
                SavedMessageApi().delete(_chat_app(), _end_user(), message_id)

    @patch("controllers.web.saved_message.SavedMessageService.delete")
    def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
        """A successful delete returns 204 with a success marker."""
        message_id = uuid4()
        with app.test_request_context(f"/saved-messages/{message_id}", method="DELETE"):
            body, status = SavedMessageApi().delete(_completion_app(), _end_user(), message_id)
            assert status == 204
            assert body["result"] == "success"

View File

@@ -0,0 +1,126 @@
"""Unit tests for controllers.web.site endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from controllers.web.site import AppSiteApi, AppSiteInfo
def _tenant(*, status: str = "normal") -> SimpleNamespace:
return SimpleNamespace(
id="tenant-1",
status=status,
plan="basic",
custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False},
)
def _site() -> SimpleNamespace:
return SimpleNamespace(
title="Site",
icon_type="emoji",
icon="robot",
icon_background="#fff",
description="desc",
default_language="en",
chat_color_theme="light",
chat_color_theme_inverted=False,
copyright=None,
privacy_policy=None,
custom_disclaimer=None,
prompt_public=False,
show_workflow_steps=True,
use_icon_as_answer_icon=False,
)
# ---------------------------------------------------------------------------
# AppSiteApi
# ---------------------------------------------------------------------------
class TestAppSiteApi:
    """GET /site — serving the public site configuration of a web app."""

    @patch("controllers.web.site.FeatureService.get_features")
    @patch("controllers.web.site.db")
    def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None:
        """A normal tenant with an existing Site row yields the marshalled site info."""
        app.config["RESTX_MASK_HEADER"] = "X-Fields"
        mock_features.return_value = SimpleNamespace(can_replace_logo=False)
        site_obj = _site()
        # The endpoint fetches the Site via db.session.query(...).where(...).first().
        mock_db.session.query.return_value.where.return_value.first.return_value = site_obj
        tenant = _tenant()
        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
        end_user = SimpleNamespace(id="eu-1")
        with app.test_request_context("/site"):
            result = AppSiteApi().get(app_model, end_user)
            # marshal_with serializes AppSiteInfo to a dict
            assert result["app_id"] == "app-1"
            assert result["plan"] == "basic"
            assert result["enable_site"] is True

    @patch("controllers.web.site.db")
    def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
        """No Site row for the app → Forbidden."""
        app.config["RESTX_MASK_HEADER"] = "X-Fields"
        mock_db.session.query.return_value.where.return_value.first.return_value = None
        tenant = _tenant()
        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
        end_user = SimpleNamespace(id="eu-1")
        with app.test_request_context("/site"):
            with pytest.raises(Forbidden):
                AppSiteApi().get(app_model, end_user)

    @patch("controllers.web.site.db")
    def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
        """An archived tenant may no longer serve its site → Forbidden."""
        app.config["RESTX_MASK_HEADER"] = "X-Fields"
        from models.account import TenantStatus
        mock_db.session.query.return_value.where.return_value.first.return_value = _site()
        tenant = SimpleNamespace(
            id="tenant-1",
            status=TenantStatus.ARCHIVE,
            plan="basic",
            custom_config_dict={},
        )
        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
        end_user = SimpleNamespace(id="eu-1")
        with app.test_request_context("/site"):
            with pytest.raises(Forbidden):
                AppSiteApi().get(app_model, end_user)
# ---------------------------------------------------------------------------
# AppSiteInfo
# ---------------------------------------------------------------------------
class TestAppSiteInfo:
    """Construction of the AppSiteInfo view object."""

    def test_basic_fields(self) -> None:
        """Plain construction copies app/user/plan fields; model_config stays unset."""
        info = AppSiteInfo(
            _tenant(),
            SimpleNamespace(id="app-1", enable_site=True),
            _site(),
            "eu-1",
            False,
        )
        assert info.app_id == "app-1"
        assert info.end_user_id == "eu-1"
        assert info.enable_site is True
        assert info.plan == "basic"
        assert info.can_replace_logo is False
        assert info.model_config is None

    @patch("controllers.web.site.dify_config", SimpleNamespace(FILES_URL="https://files.example.com"))
    def test_can_replace_logo_sets_custom_config(self) -> None:
        """With logo replacement allowed, custom_config exposes the webapp-logo URL."""
        branded_tenant = SimpleNamespace(
            id="tenant-1",
            plan="pro",
            custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True},
        )
        info = AppSiteInfo(branded_tenant, SimpleNamespace(id="app-1", enable_site=True), _site(), "eu-1", True)
        assert info.can_replace_logo is True
        assert info.custom_config["remove_webapp_brand"] is True
        assert "webapp-logo" in info.custom_config["replace_webapp_logo"]

View File

@@ -5,7 +5,8 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
import services.errors.account
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi
def encode_code(code: str) -> str:
@@ -89,3 +90,114 @@ class TestEmailCodeLoginApi:
mock_revoke_token.assert_called_once_with("token-123")
mock_login.assert_called_once()
mock_reset_login_rate.assert_called_once_with("user@example.com")
class TestLoginApi:
    """POST /web/login — email/password authentication for web apps."""

    @patch("controllers.web.login.WebAppAuthService.login", return_value="access-tok")
    @patch("controllers.web.login.WebAppAuthService.authenticate")
    def test_login_success(self, mock_auth: MagicMock, mock_login: MagicMock, app: Flask) -> None:
        """Valid credentials return the issued access token."""
        mock_auth.return_value = MagicMock()
        # The endpoint expects the password base64-encoded in the JSON body.
        with app.test_request_context(
            "/web/login",
            method="POST",
            json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
        ):
            response = LoginApi().post()
            assert response.get_json()["data"]["access_token"] == "access-tok"
            mock_auth.assert_called_once()

    @patch(
        "controllers.web.login.WebAppAuthService.authenticate",
        side_effect=services.errors.account.AccountLoginError(),
    )
    def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None:
        """A banned account's login error is translated into AccountBannedError."""
        from controllers.console.error import AccountBannedError
        with app.test_request_context(
            "/web/login",
            method="POST",
            json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
        ):
            with pytest.raises(AccountBannedError):
                LoginApi().post()

    @patch(
        "controllers.web.login.WebAppAuthService.authenticate",
        side_effect=services.errors.account.AccountPasswordError(),
    )
    def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None:
        """A wrong password is translated into AuthenticationFailedError."""
        from controllers.console.auth.error import AuthenticationFailedError
        with app.test_request_context(
            "/web/login",
            method="POST",
            json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
        ):
            with pytest.raises(AuthenticationFailedError):
                LoginApi().post()
class TestLoginStatusApi:
    """GET /web/login/status — reports site-level and app-level login state."""

    @patch("controllers.web.login.extract_webapp_access_token", return_value=None)
    def test_no_app_code_returns_logged_in_false(self, mock_extract: MagicMock, app: Flask) -> None:
        """With no access token (and no app_code) both flags are False."""
        with app.test_request_context("/web/login/status"):
            result = LoginStatusApi().get()
            assert result["logged_in"] is False
            assert result["app_logged_in"] is False

    @patch("controllers.web.login.decode_jwt_token")
    @patch("controllers.web.login.PassportService")
    @patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=False)
    @patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1")
    @patch("controllers.web.login.extract_webapp_access_token", return_value="tok")
    def test_public_app_user_logged_in(
        self,
        mock_extract: MagicMock,
        mock_app_id: MagicMock,
        mock_perm: MagicMock,
        mock_passport: MagicMock,
        mock_decode: MagicMock,
        app: Flask,
    ) -> None:
        """A public app with a decodable token reports both flags True."""
        # decode_jwt_token returns an (app_model, end_user) pair on success.
        mock_decode.return_value = (MagicMock(), MagicMock())
        with app.test_request_context("/web/login/status?app_code=code1"):
            result = LoginStatusApi().get()
            assert result["logged_in"] is True
            assert result["app_logged_in"] is True

    @patch("controllers.web.login.decode_jwt_token", side_effect=Exception("bad"))
    @patch("controllers.web.login.PassportService")
    @patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=True)
    @patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1")
    @patch("controllers.web.login.extract_webapp_access_token", return_value="tok")
    def test_private_app_passport_fails(
        self,
        mock_extract: MagicMock,
        mock_app_id: MagicMock,
        mock_perm: MagicMock,
        mock_passport_cls: MagicMock,
        mock_decode: MagicMock,
        app: Flask,
    ) -> None:
        """A permission-checked app whose passport verification fails reports False."""
        mock_passport_cls.return_value.verify.side_effect = Exception("bad")
        with app.test_request_context("/web/login/status?app_code=code1"):
            result = LoginStatusApi().get()
            assert result["logged_in"] is False
            assert result["app_logged_in"] is False
class TestLogoutApi:
    """POST /web/logout."""

    @patch("controllers.web.login.clear_webapp_access_token_from_cookie")
    def test_logout_success(self, mock_clear: MagicMock, app: Flask) -> None:
        """Logout reports success and clears the access-token cookie exactly once."""
        with app.test_request_context("/web/logout", method="POST"):
            payload = LogoutApi().post().get_json()
            assert payload == {"result": "success"}
            mock_clear.assert_called_once()

View File

@@ -0,0 +1,192 @@
"""Unit tests for controllers.web.passport — token issuance and enterprise auth exchange."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.web.error import WebAppAuthRequiredError
from controllers.web.passport import (
PassportResource,
decode_enterprise_webapp_user_id,
exchange_token_for_existing_web_user,
generate_session_id,
)
from services.webapp_auth_service import WebAppAuthType
# ---------------------------------------------------------------------------
# decode_enterprise_webapp_user_id
# ---------------------------------------------------------------------------
class TestDecodeEnterpriseWebappUserId:
    """Decoding of enterprise webapp login tokens."""

    def test_none_token_returns_none(self) -> None:
        """A missing token short-circuits to None."""
        assert decode_enterprise_webapp_user_id(None) is None

    @patch("controllers.web.passport.PassportService")
    def test_valid_token_returns_decoded(self, mock_passport_cls: MagicMock) -> None:
        """A token sourced from webapp_login_token is returned decoded."""
        claims = {"token_source": "webapp_login_token", "user_id": "u1"}
        mock_passport_cls.return_value.verify.return_value = claims
        decoded = decode_enterprise_webapp_user_id("valid-jwt")
        assert decoded["user_id"] == "u1"

    @patch("controllers.web.passport.PassportService")
    def test_wrong_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None:
        """Any other token_source is rejected as Unauthorized."""
        mock_passport_cls.return_value.verify.return_value = {"token_source": "other_source"}
        with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"):
            decode_enterprise_webapp_user_id("bad-jwt")

    @patch("controllers.web.passport.PassportService")
    def test_missing_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None:
        """A token without any token_source claim is likewise rejected."""
        mock_passport_cls.return_value.verify.return_value = {}
        with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"):
            decode_enterprise_webapp_user_id("no-source-jwt")
# ---------------------------------------------------------------------------
# generate_session_id
# ---------------------------------------------------------------------------
class TestGenerateSessionId:
    """Session-id generation with collision retry."""

    @patch("controllers.web.passport.db")
    def test_returns_unique_session_id(self, mock_db: MagicMock) -> None:
        """With no collision, the first candidate id is returned as-is."""
        mock_db.session.scalar.return_value = 0
        session_id = generate_session_id()
        assert isinstance(session_id, str)
        # canonical UUID string length
        assert len(session_id) == 36

    @patch("controllers.web.passport.db")
    def test_retries_on_collision(self, mock_db: MagicMock) -> None:
        """A colliding first candidate (count=1) forces a second draw."""
        mock_db.session.scalar.side_effect = [1, 0]
        session_id = generate_session_id()
        assert isinstance(session_id, str)
        assert mock_db.session.scalar.call_count == 2
# ---------------------------------------------------------------------------
# exchange_token_for_existing_web_user
# ---------------------------------------------------------------------------
class TestExchangeTokenForExistingWebUser:
    """Token exchange for an already-authenticated enterprise web user."""

    @patch("controllers.web.passport.PassportService")
    @patch("controllers.web.passport.db")
    def test_external_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
        """The decoded auth_type must match the requested auth type (external case)."""
        site = SimpleNamespace(code="code1", app_id="app-1")
        app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
        # db.session.scalar is consulted twice: first for the Site, then for the App.
        mock_db.session.scalar.side_effect = [site, app_model]
        decoded = {"user_id": "u1", "auth_type": "internal"}  # mismatch: expected "external"
        with pytest.raises(WebAppAuthRequiredError, match="external"):
            exchange_token_for_existing_web_user(
                app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL
            )

    @patch("controllers.web.passport.PassportService")
    @patch("controllers.web.passport.db")
    def test_internal_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
        """The decoded auth_type must match the requested auth type (internal case)."""
        site = SimpleNamespace(code="code1", app_id="app-1")
        app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
        mock_db.session.scalar.side_effect = [site, app_model]
        decoded = {"user_id": "u1", "auth_type": "external"}  # mismatch: expected "internal"
        with pytest.raises(WebAppAuthRequiredError, match="internal"):
            exchange_token_for_existing_web_user(
                app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.INTERNAL
            )

    @patch("controllers.web.passport.PassportService")
    @patch("controllers.web.passport.db")
    def test_site_not_found_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
        """An unknown app code (no Site row) raises NotFound."""
        mock_db.session.scalar.return_value = None
        decoded = {"user_id": "u1", "auth_type": "external"}
        with pytest.raises(NotFound):
            exchange_token_for_existing_web_user(
                app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL
            )
# ---------------------------------------------------------------------------
# PassportResource.get
# ---------------------------------------------------------------------------
class TestPassportResource:
    """GET /passport — anonymous token issuance for public web apps."""

    @patch("controllers.web.passport.FeatureService.get_system_features")
    def test_missing_app_code_raises_unauthorized(self, mock_features: MagicMock, app: Flask) -> None:
        """A request without the X-App-Code header is rejected."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
        with app.test_request_context("/passport"):
            with pytest.raises(Unauthorized, match="X-App-Code"):
                PassportResource().get()

    @patch("controllers.web.passport.PassportService")
    @patch("controllers.web.passport.generate_session_id", return_value="new-sess-id")
    @patch("controllers.web.passport.db")
    @patch("controllers.web.passport.FeatureService.get_system_features")
    def test_creates_new_end_user_when_no_user_id(
        self,
        mock_features: MagicMock,
        mock_db: MagicMock,
        mock_gen_session: MagicMock,
        mock_passport_cls: MagicMock,
        app: Flask,
    ) -> None:
        """Without a user_id query param, a fresh EndUser row is created and committed."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
        site = SimpleNamespace(app_id="app-1", code="code1")
        app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
        # scalar() resolves the Site first, then the App.
        mock_db.session.scalar.side_effect = [site, app_model]
        mock_passport_cls.return_value.issue.return_value = "issued-token"
        with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
            response = PassportResource().get()
            assert response.get_json()["access_token"] == "issued-token"
            mock_db.session.add.assert_called_once()
            mock_db.session.commit.assert_called_once()

    @patch("controllers.web.passport.PassportService")
    @patch("controllers.web.passport.db")
    @patch("controllers.web.passport.FeatureService.get_system_features")
    def test_reuses_existing_end_user_when_user_id_provided(
        self,
        mock_features: MagicMock,
        mock_db: MagicMock,
        mock_passport_cls: MagicMock,
        app: Flask,
    ) -> None:
        """A matching user_id (session id) reuses the existing EndUser row."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
        site = SimpleNamespace(app_id="app-1", code="code1")
        app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
        existing_user = SimpleNamespace(id="eu-1", session_id="sess-existing")
        # scalar() resolves the Site, the App, then the existing EndUser.
        mock_db.session.scalar.side_effect = [site, app_model, existing_user]
        mock_passport_cls.return_value.issue.return_value = "reused-token"
        with app.test_request_context("/passport?user_id=sess-existing", headers={"X-App-Code": "code1"}):
            response = PassportResource().get()
            assert response.get_json()["access_token"] == "reused-token"
            # Should not create a new end user
            mock_db.session.add.assert_not_called()

    @patch("controllers.web.passport.db")
    @patch("controllers.web.passport.FeatureService.get_system_features")
    def test_site_not_found_raises(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None:
        """An unknown app code (no Site row) raises NotFound."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
        mock_db.session.scalar.return_value = None
        with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
            with pytest.raises(NotFound):
                PassportResource().get()

    @patch("controllers.web.passport.db")
    @patch("controllers.web.passport.FeatureService.get_system_features")
    def test_disabled_app_raises_not_found(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None:
        """An app whose site is disabled (enable_site=False) raises NotFound."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
        site = SimpleNamespace(app_id="app-1", code="code1")
        disabled_app = SimpleNamespace(id="app-1", status="normal", enable_site=False)
        mock_db.session.scalar.side_effect = [site, disabled_app]
        with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
            with pytest.raises(NotFound):
                PassportResource().get()

View File

@@ -0,0 +1,95 @@
"""Unit tests for controllers.web.workflow endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.error import (
NotWorkflowAppError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.web.workflow import WorkflowRunApi, WorkflowTaskStopApi
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError
def _workflow_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="workflow")
def _chat_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="chat")
def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1")
# ---------------------------------------------------------------------------
# WorkflowRunApi
# ---------------------------------------------------------------------------
class TestWorkflowRunApi:
    """POST /workflows/run."""

    def test_wrong_mode_raises(self, app: Flask) -> None:
        """Running a workflow on a chat-mode app is rejected."""
        with app.test_request_context("/workflows/run", method="POST"):
            with pytest.raises(NotWorkflowAppError):
                WorkflowRunApi().post(_chat_app(), _end_user())

    @patch("controllers.web.workflow.helper.compact_generate_response", return_value={"result": "ok"})
    @patch("controllers.web.workflow.AppGenerateService.generate")
    @patch("controllers.web.workflow.web_ns")
    def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
        """A successful generation is passed through compact_generate_response."""
        mock_ns.payload = {"inputs": {"key": "val"}}
        mock_gen.return_value = "response"
        with app.test_request_context("/workflows/run", method="POST"):
            result = WorkflowRunApi().post(_workflow_app(), _end_user())
            assert result == {"result": "ok"}

    @patch(
        "controllers.web.workflow.AppGenerateService.generate",
        side_effect=ProviderTokenNotInitError(description="not init"),
    )
    @patch("controllers.web.workflow.web_ns")
    def test_provider_not_init(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
        """An uninitialized provider is translated into ProviderNotInitializeError."""
        mock_ns.payload = {"inputs": {}}
        with app.test_request_context("/workflows/run", method="POST"):
            with pytest.raises(ProviderNotInitializeError):
                WorkflowRunApi().post(_workflow_app(), _end_user())

    @patch(
        "controllers.web.workflow.AppGenerateService.generate",
        side_effect=QuotaExceededError(),
    )
    @patch("controllers.web.workflow.web_ns")
    def test_quota_exceeded(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
        """Quota exhaustion is translated into ProviderQuotaExceededError."""
        mock_ns.payload = {"inputs": {}}
        with app.test_request_context("/workflows/run", method="POST"):
            with pytest.raises(ProviderQuotaExceededError):
                WorkflowRunApi().post(_workflow_app(), _end_user())
# ---------------------------------------------------------------------------
# WorkflowTaskStopApi
# ---------------------------------------------------------------------------
class TestWorkflowTaskStopApi:
    """POST /workflows/tasks/<task_id>/stop."""

    def test_wrong_mode_raises(self, app: Flask) -> None:
        """Stopping a task on a non-workflow app is rejected."""
        request_ctx = app.test_request_context("/workflows/tasks/task-1/stop", method="POST")
        with request_ctx, pytest.raises(NotWorkflowAppError):
            WorkflowTaskStopApi().post(_chat_app(), _end_user(), "task-1")

    @patch("controllers.web.workflow.GraphEngineManager.send_stop_command")
    @patch("controllers.web.workflow.AppQueueManager.set_stop_flag_no_user_check")
    def test_stop_calls_both_mechanisms(self, mock_legacy: MagicMock, mock_graph: MagicMock, app: Flask) -> None:
        """Stopping signals both the legacy queue flag and the graph engine."""
        with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
            outcome = WorkflowTaskStopApi().post(_workflow_app(), _end_user(), "task-1")
            assert outcome == {"result": "success"}
            mock_legacy.assert_called_once_with("task-1")
            mock_graph.assert_called_once_with("task-1")

View File

@@ -0,0 +1,127 @@
"""Unit tests for controllers.web.workflow_events endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.error import NotFoundError
from controllers.web.workflow_events import WorkflowEventsApi
from models.enums import CreatorUserRole
def _workflow_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="workflow")
def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1")
# ---------------------------------------------------------------------------
# WorkflowEventsApi
# ---------------------------------------------------------------------------
class TestWorkflowEventsApi:
    """GET /workflow/<run_id>/events — SSE replay of a workflow run's events."""

    @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
    @patch("controllers.web.workflow_events.db")
    def test_workflow_run_not_found(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
        """An unknown run id raises NotFoundError."""
        mock_db.engine = "engine"
        mock_repo = MagicMock()
        mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = None
        mock_factory.create_api_workflow_run_repository.return_value = mock_repo
        with app.test_request_context("/workflow/run-1/events"):
            with pytest.raises(NotFoundError):
                WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")

    @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
    @patch("controllers.web.workflow_events.db")
    def test_workflow_run_wrong_app(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
        """A run belonging to a different app is hidden behind NotFoundError."""
        mock_db.engine = "engine"
        run = SimpleNamespace(
            id="run-1",
            app_id="other-app",
            created_by_role=CreatorUserRole.END_USER,
            created_by="eu-1",
            finished_at=None,
        )
        mock_repo = MagicMock()
        mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
        mock_factory.create_api_workflow_run_repository.return_value = mock_repo
        with app.test_request_context("/workflow/run-1/events"):
            with pytest.raises(NotFoundError):
                WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")

    @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
    @patch("controllers.web.workflow_events.db")
    def test_workflow_run_not_created_by_end_user(
        self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask
    ) -> None:
        """A run created by a console account is not visible to web end users."""
        mock_db.engine = "engine"
        run = SimpleNamespace(
            id="run-1",
            app_id="app-1",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by="eu-1",
            finished_at=None,
        )
        mock_repo = MagicMock()
        mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
        mock_factory.create_api_workflow_run_repository.return_value = mock_repo
        with app.test_request_context("/workflow/run-1/events"):
            with pytest.raises(NotFoundError):
                WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")

    @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
    @patch("controllers.web.workflow_events.db")
    def test_workflow_run_wrong_end_user(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
        """A run created by a different end user is not visible to this one."""
        mock_db.engine = "engine"
        run = SimpleNamespace(
            id="run-1",
            app_id="app-1",
            created_by_role=CreatorUserRole.END_USER,
            created_by="other-user",
            finished_at=None,
        )
        mock_repo = MagicMock()
        mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
        mock_factory.create_api_workflow_run_repository.return_value = mock_repo
        with app.test_request_context("/workflow/run-1/events"):
            with pytest.raises(NotFoundError):
                WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")

    @patch("controllers.web.workflow_events.WorkflowResponseConverter")
    @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
    @patch("controllers.web.workflow_events.db")
    def test_finished_run_returns_sse_response(
        self, mock_db: MagicMock, mock_factory: MagicMock, mock_converter: MagicMock, app: Flask
    ) -> None:
        """An already-finished run is replayed as a text/event-stream response."""
        from datetime import datetime
        mock_db.engine = "engine"
        run = SimpleNamespace(
            id="run-1",
            app_id="app-1",
            created_by_role=CreatorUserRole.END_USER,
            created_by="eu-1",
            finished_at=datetime(2024, 1, 1),
        )
        mock_repo = MagicMock()
        mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
        mock_factory.create_api_workflow_run_repository.return_value = mock_repo
        finish_response = MagicMock()
        finish_response.model_dump.return_value = {"task_id": "run-1"}
        finish_response.event.value = "workflow_finished"
        mock_converter.workflow_run_result_to_finish_response.return_value = finish_response
        with app.test_request_context("/workflow/run-1/events"):
            response = WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
            assert response.mimetype == "text/event-stream"

View File

@@ -0,0 +1,393 @@
"""Unit tests for controllers.web.wraps — JWT auth decorator and validation helpers."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
from controllers.web.wraps import (
_validate_user_accessibility,
_validate_webapp_token,
decode_jwt_token,
)
# ---------------------------------------------------------------------------
# _validate_webapp_token
# ---------------------------------------------------------------------------
class TestValidateWebappToken:
    """_validate_webapp_token — token_source checks against the app/system auth flags."""

    def test_enterprise_enabled_and_app_auth_requires_webapp_source(self) -> None:
        """When both flags are true, a non-webapp source must raise."""
        with pytest.raises(WebAppAuthRequiredError):
            _validate_webapp_token(
                {"token_source": "other"}, app_web_auth_enabled=True, system_webapp_auth_enabled=True
            )

    def test_enterprise_enabled_and_app_auth_accepts_webapp_source(self) -> None:
        """A webapp-sourced token passes when auth is required."""
        _validate_webapp_token({"token_source": "webapp"}, app_web_auth_enabled=True, system_webapp_auth_enabled=True)

    def test_enterprise_enabled_and_app_auth_missing_source_raises(self) -> None:
        """An absent token_source counts as non-webapp and is rejected."""
        with pytest.raises(WebAppAuthRequiredError):
            _validate_webapp_token({}, app_web_auth_enabled=True, system_webapp_auth_enabled=True)

    def test_public_app_rejects_webapp_source(self) -> None:
        """When auth is not required, a webapp-sourced token must be rejected."""
        with pytest.raises(Unauthorized):
            _validate_webapp_token(
                {"token_source": "webapp"}, app_web_auth_enabled=False, system_webapp_auth_enabled=False
            )

    def test_public_app_accepts_non_webapp_source(self) -> None:
        """A non-webapp token is accepted on a public app."""
        _validate_webapp_token({"token_source": "other"}, app_web_auth_enabled=False, system_webapp_auth_enabled=False)

    def test_public_app_accepts_no_source(self) -> None:
        """A token with no source at all is accepted on a public app."""
        _validate_webapp_token({}, app_web_auth_enabled=False, system_webapp_auth_enabled=False)

    def test_system_enabled_but_app_public(self) -> None:
        """system_webapp_auth_enabled=True but app is public — webapp source rejected."""
        with pytest.raises(Unauthorized):
            _validate_webapp_token(
                {"token_source": "webapp"}, app_web_auth_enabled=False, system_webapp_auth_enabled=True
            )
# ---------------------------------------------------------------------------
# _validate_user_accessibility
# ---------------------------------------------------------------------------
class TestValidateUserAccessibility:
    """_validate_user_accessibility — per-user access checks for protected web apps."""

    def test_skips_when_auth_disabled(self) -> None:
        """No checks when system or app auth is disabled."""
        _validate_user_accessibility(
            decoded={},
            app_code="code",
            app_web_auth_enabled=False,
            system_webapp_auth_enabled=False,
            webapp_settings=None,
        )

    def test_missing_user_id_raises(self) -> None:
        """An authed app requires a user_id claim in the decoded token."""
        decoded = {}
        with pytest.raises(WebAppAuthRequiredError):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=SimpleNamespace(access_mode="internal"),
            )

    def test_missing_webapp_settings_raises(self) -> None:
        """Auth enabled but no webapp settings available → rejected."""
        decoded = {"user_id": "u1"}
        with pytest.raises(WebAppAuthRequiredError, match="settings not found"):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=None,
            )

    def test_missing_auth_type_raises(self) -> None:
        """The decoded token must carry an auth_type claim."""
        decoded = {"user_id": "u1", "granted_at": 1}
        settings = SimpleNamespace(access_mode="public")
        with pytest.raises(WebAppAuthAccessDeniedError, match="auth_type"):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=settings,
            )

    def test_missing_granted_at_raises(self) -> None:
        """The decoded token must carry a granted_at timestamp."""
        decoded = {"user_id": "u1", "auth_type": "external"}
        settings = SimpleNamespace(access_mode="public")
        with pytest.raises(WebAppAuthAccessDeniedError, match="granted_at"):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=settings,
            )

    @patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time")
    @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
    def test_external_auth_type_checks_sso_update_time(
        self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
    ) -> None:
        """External auth: a token granted before the app SSO settings changed is denied."""
        # granted_at is before SSO update time → denied
        mock_sso_time.return_value = datetime.now(UTC)
        old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
        decoded = {"user_id": "u1", "auth_type": "external", "granted_at": old_granted}
        settings = SimpleNamespace(access_mode="public")
        with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=settings,
            )

    @patch("controllers.web.wraps.EnterpriseService.get_workspace_sso_settings_last_update_time")
    @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
    def test_internal_auth_type_checks_workspace_sso_update_time(
        self, mock_perm_check: MagicMock, mock_workspace_sso: MagicMock
    ) -> None:
        """Internal auth: the workspace-level SSO update time is consulted instead."""
        mock_workspace_sso.return_value = datetime.now(UTC)
        old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
        decoded = {"user_id": "u1", "auth_type": "internal", "granted_at": old_granted}
        settings = SimpleNamespace(access_mode="public")
        with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=settings,
            )

    @patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time")
    @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
    def test_external_auth_passes_when_granted_after_sso_update(
        self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
    ) -> None:
        """A token granted after the last SSO settings change remains valid."""
        mock_sso_time.return_value = datetime.now(UTC) - timedelta(hours=2)
        recent_granted = int(datetime.now(UTC).timestamp())
        decoded = {"user_id": "u1", "auth_type": "external", "granted_at": recent_granted}
        settings = SimpleNamespace(access_mode="public")
        # Should not raise
        _validate_user_accessibility(
            decoded=decoded,
            app_code="code",
            app_web_auth_enabled=True,
            system_webapp_auth_enabled=True,
            webapp_settings=settings,
        )

    @patch("controllers.web.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp", return_value=False)
    @patch("controllers.web.wraps.AppService.get_app_id_by_code", return_value="app-id-1")
    @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=True)
    def test_permission_check_denies_unauthorized_user(
        self, mock_perm: MagicMock, mock_app_id: MagicMock, mock_allowed: MagicMock
    ) -> None:
        """A permission-checked app denies users not on its access list."""
        decoded = {"user_id": "u1", "auth_type": "external", "granted_at": int(datetime.now(UTC).timestamp())}
        settings = SimpleNamespace(access_mode="internal")
        with pytest.raises(WebAppAuthAccessDeniedError):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=settings,
            )
# ---------------------------------------------------------------------------
# decode_jwt_token
# ---------------------------------------------------------------------------
class TestDecodeJwtToken:
# Exercises decode_jwt_token(): extract the webapp passport from the request,
# verify it, then resolve (App, EndUser) from the database. Every collaborator
# is patched out. NOTE: @patch decorators inject mocks bottom-up, so the
# lowest decorator (db) maps to the first mock parameter.
# Happy path: valid token, enabled site, matching end user -> (app, end_user).
@patch("controllers.web.wraps._validate_user_accessibility")
@patch("controllers.web.wraps._validate_webapp_token")
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
@patch("controllers.web.wraps.AppService.get_app_id_by_code")
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_happy_path(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
mock_app_id: MagicMock,
mock_access_mode: MagicMock,
mock_validate_token: MagicMock,
mock_validate_user: MagicMock,
app: Flask,
) -> None:
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
app_model = SimpleNamespace(id="app-1", enable_site=True)
site = SimpleNamespace(code="code1")
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
# Configure session mock to return correct objects via scalar()
# Ordering matters: the implementation looks up App, then Site, then EndUser.
session_mock = MagicMock()
session_mock.scalar.side_effect = [app_model, site, end_user]
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
result_app, result_user = decode_jwt_token()
assert result_app.id == "app-1"
assert result_user.id == "eu-1"
# No passport on the request -> Unauthorized before any DB work happens.
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.extract_webapp_passport")
def test_missing_token_raises_unauthorized(
self, mock_extract: MagicMock, mock_features: MagicMock, app: Flask
) -> None:
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
mock_extract.return_value = None
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(Unauthorized):
decode_jwt_token()
# App row not found for the token's app_id -> NotFound.
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_missing_app_raises_not_found(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app: Flask,
) -> None:
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
session_mock = MagicMock()
session_mock.scalar.return_value = None  # No app found
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(NotFound):
decode_jwt_token()
# App exists but its site is disabled -> BadRequest("Site is disabled").
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_disabled_site_raises_bad_request(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app: Flask,
) -> None:
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
app_model = SimpleNamespace(id="app-1", enable_site=False)
session_mock = MagicMock()
# scalar calls: app_model, site (code found), then end_user
session_mock.scalar.side_effect = [app_model, SimpleNamespace(code="code1"), None]
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(BadRequest, match="Site is disabled"):
decode_jwt_token()
# App and site resolve but the EndUser row is missing -> NotFound.
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_missing_end_user_raises_not_found(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app: Flask,
) -> None:
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
app_model = SimpleNamespace(id="app-1", enable_site=True)
site = SimpleNamespace(code="code1")
session_mock = MagicMock()
session_mock.scalar.side_effect = [app_model, site, None]  # end_user is None
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(NotFound):
decode_jwt_token()
# Caller pins a different user_id than the token's end user -> treated as an
# expired/invalid session (Unauthorized matching "expired").
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_user_id_mismatch_raises_unauthorized(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app: Flask,
) -> None:
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
app_model = SimpleNamespace(id="app-1", enable_site=True)
site = SimpleNamespace(code="code1")
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
session_mock = MagicMock()
session_mock.scalar.side_effect = [app_model, site, end_user]
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(Unauthorized, match="expired"):
decode_jwt_token(user_id="different-user")

View File

@@ -9,8 +9,16 @@ import pytest
from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent
from core.app.entities.queue_entities import (
QueuePingEvent,
QueueTextChunkEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import StreamEvent
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import WorkflowExecutionStatus
from models.enums import MessageStatus
from models.execution_extra_content import HumanInputContent
from models.model import EndUser
@@ -185,3 +193,97 @@ def test_resume_appends_chunks_to_paused_answer() -> None:
assert message.answer == "beforeafter"
assert message.status == MessageStatus.NORMAL
def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None:
    """On QueueWorkflowSucceededEvent, MESSAGE_END must be emitted before WORKFLOW_FINISHED."""
    task_pipeline = _build_pipeline()
    task_pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
    task_pipeline._workflow_id = "workflow-1"
    task_pipeline._ensure_workflow_initialized = mock.Mock()
    task_pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=SimpleNamespace())
    # The message-end handler yields exactly one MESSAGE_END stream response.
    task_pipeline._handle_advanced_chat_message_end_event = mock.Mock(
        return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)])
    )
    converter = mock.Mock()
    converter.workflow_finish_to_stream_response.return_value = SimpleNamespace(
        event=StreamEvent.WORKFLOW_FINISHED,
        data=SimpleNamespace(status=WorkflowExecutionStatus.SUCCEEDED),
    )
    task_pipeline._workflow_response_converter = converter

    succeeded = QueueWorkflowSucceededEvent(outputs={})
    emitted = list(task_pipeline._handle_workflow_succeeded_event(succeeded))

    # Ordering is the point of this test: message end first, then workflow finish.
    assert [item.event for item in emitted] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED]
def test_workflow_partial_success_emits_message_end_before_workflow_finished() -> None:
    """On QueueWorkflowPartialSuccessEvent, MESSAGE_END must precede WORKFLOW_FINISHED."""
    task_pipeline = _build_pipeline()
    task_pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
    task_pipeline._workflow_id = "workflow-1"
    task_pipeline._ensure_workflow_initialized = mock.Mock()
    task_pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=SimpleNamespace())
    # One MESSAGE_END response comes out of the message-end handler.
    task_pipeline._handle_advanced_chat_message_end_event = mock.Mock(
        return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)])
    )
    converter = mock.Mock()
    converter.workflow_finish_to_stream_response.return_value = SimpleNamespace(
        event=StreamEvent.WORKFLOW_FINISHED,
        data=SimpleNamespace(status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED),
    )
    task_pipeline._workflow_response_converter = converter

    partial = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
    emitted = list(task_pipeline._handle_workflow_partial_success_event(partial))

    assert [item.event for item in emitted] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED]
def test_process_stream_response_breaks_after_workflow_succeeded() -> None:
    """The stream loop must stop consuming queue messages once the workflow succeeds."""
    task_pipeline = _build_pipeline()
    succeeded = QueueWorkflowSucceededEvent(outputs={})
    # A ping queued *after* the success event must never be processed.
    pending = [
        SimpleNamespace(event=succeeded),
        SimpleNamespace(event=QueuePingEvent()),
    ]
    task_pipeline._conversation_name_generate_thread = None
    base = mock.Mock()
    base.queue_manager = mock.Mock()
    base.queue_manager.listen.return_value = iter(pending)
    base.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING))
    task_pipeline._base_task_pipeline = base
    task_pipeline._handle_workflow_succeeded_event = mock.Mock(
        return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)])
    )

    emitted = list(task_pipeline._process_stream_response())

    assert [item.event for item in emitted] == [StreamEvent.WORKFLOW_FINISHED]
    task_pipeline._handle_workflow_succeeded_event.assert_called_once_with(succeeded, trace_manager=None)
    # Proves the loop broke before reaching the trailing ping message.
    base.ping_stream_response.assert_not_called()
def test_process_stream_response_breaks_after_workflow_partial_success() -> None:
    """The stream loop must stop consuming queue messages after a partial success."""
    task_pipeline = _build_pipeline()
    partial = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
    # A ping queued behind the terminal event must never be processed.
    pending = [
        SimpleNamespace(event=partial),
        SimpleNamespace(event=QueuePingEvent()),
    ]
    task_pipeline._conversation_name_generate_thread = None
    base = mock.Mock()
    base.queue_manager = mock.Mock()
    base.queue_manager.listen.return_value = iter(pending)
    base.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING))
    task_pipeline._base_task_pipeline = base
    task_pipeline._handle_workflow_partial_success_event = mock.Mock(
        return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)])
    )

    emitted = list(task_pipeline._process_stream_response())

    assert [item.event for item in emitted] == [StreamEvent.WORKFLOW_FINISHED]
    task_pipeline._handle_workflow_partial_success_event.assert_called_once_with(partial, trace_manager=None)
    base.ping_stream_response.assert_not_called()

View File

@@ -124,12 +124,12 @@ def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch):
def start(self):
self.started = True
def fake_thread(**kwargs):
def fake_thread(*args, **kwargs):
thread = DummyThread(**kwargs)
captured["thread"] = thread
return thread
monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread)
monkeypatch.setattr(message_cycle_manager, "Timer", fake_thread)
manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock())
thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello")

View File

@@ -1,13 +1,8 @@
import sys
import time
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
API_DIR = str(Path(__file__).resolve().parents[5])
if API_DIR not in sys.path:
sys.path.insert(0, API_DIR)
import dify_graph.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module

View File

@@ -0,0 +1,425 @@
"""
Unit tests for EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response method.
This test suite ensures that the files array is correctly populated in the message_end
SSE event, which is critical for vision/image chat responses to render correctly.
Test Coverage:
- Files array populated when MessageFile records exist
- Files array is None when no MessageFile records exist
- Correct signed URL generation for LOCAL_FILE transfer method
- Correct URL handling for REMOTE_URL transfer method
- Correct URL handling for TOOL_FILE transfer method
- Proper file metadata formatting (filename, mime_type, size, extension)
"""
import uuid
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy.orm import Session
from core.app.entities.task_entities import MessageEndStreamResponse
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from dify_graph.file.enums import FileTransferMethod
from models.model import MessageFile, UploadFile
class TestMessageEndStreamResponseFiles:
"""Test suite for files array population in message_end SSE event."""
# All tests call the target method unbound on a Mock "pipeline" instance and
# patch db/Session at the module under test. Where two queries are expected
# (MessageFile then UploadFile), scalars() is driven by an ordered side_effect.
@pytest.fixture
def mock_pipeline(self):
"""Create a mock EasyUIBasedGenerateTaskPipeline instance."""
pipeline = Mock(spec=EasyUIBasedGenerateTaskPipeline)
pipeline._message_id = str(uuid.uuid4())
pipeline._task_state = Mock()
pipeline._task_state.metadata = Mock()
pipeline._task_state.metadata.model_dump = Mock(return_value={"test": "metadata"})
pipeline._task_state.llm_result = Mock()
pipeline._task_state.llm_result.usage = Mock()
pipeline._application_generate_entity = Mock()
pipeline._application_generate_entity.task_id = str(uuid.uuid4())
return pipeline
@pytest.fixture
def mock_message_file_local(self):
"""Create a mock MessageFile with LOCAL_FILE transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
message_file.upload_file_id = str(uuid.uuid4())
message_file.url = None
message_file.type = "image"
return message_file
@pytest.fixture
def mock_message_file_remote(self):
"""Create a mock MessageFile with REMOTE_URL transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.REMOTE_URL
message_file.upload_file_id = None
message_file.url = "https://example.com/image.jpg"
message_file.type = "image"
return message_file
@pytest.fixture
def mock_message_file_tool(self):
"""Create a mock MessageFile with TOOL_FILE transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.TOOL_FILE
message_file.upload_file_id = None
message_file.url = "tool_file_123.png"
message_file.type = "image"
return message_file
@pytest.fixture
def mock_upload_file(self, mock_message_file_local):
"""Create a mock UploadFile."""
# Shares its id with mock_message_file_local.upload_file_id so the
# batch UploadFile lookup in the code under test can match them up.
upload_file = Mock(spec=UploadFile)
upload_file.id = mock_message_file_local.upload_file_id
upload_file.name = "test_image.png"
upload_file.mime_type = "image/png"
upload_file.size = 1024
upload_file.extension = "png"
return upload_file
def test_message_end_with_no_files(self, mock_pipeline):
"""Test that files array is None when no MessageFile records exist."""
# Arrange
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.scalars.return_value.all.return_value = []
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is None
assert result.id == mock_pipeline._message_id
assert result.metadata == {"test": "metadata"}
def test_message_end_with_local_file(self, mock_pipeline, mock_message_file_local, mock_upload_file):
"""Test that files array is populated correctly for LOCAL_FILE transfer method."""
# Arrange
mock_message_file_local.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
# First query: MessageFile
mock_message_files_result = Mock()
mock_message_files_result.all.return_value = [mock_message_file_local]
# Second query: UploadFile (batch query to avoid N+1)
mock_upload_files_result = Mock()
mock_upload_files_result.all.return_value = [mock_upload_file]
# Setup scalars to return different results for different queries
call_count = [0]  # Use list to allow modification in nested function
def scalars_side_effect(query):
call_count[0] += 1
# First call is for MessageFile, second call is for UploadFile
if call_count[0] == 1:
return mock_message_files_result
else:
return mock_upload_files_result
mock_session.scalars.side_effect = scalars_side_effect
mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert file_dict["related_id"] == mock_message_file_local.id
assert file_dict["filename"] == "test_image.png"
assert file_dict["mime_type"] == "image/png"
assert file_dict["size"] == 1024
assert file_dict["extension"] == ".png"
assert file_dict["type"] == "image"
assert file_dict["transfer_method"] == FileTransferMethod.LOCAL_FILE.value
assert "https://example.com/signed-url" in file_dict["url"]
assert file_dict["upload_file_id"] == mock_message_file_local.upload_file_id
assert file_dict["remote_url"] == ""
# Verify database queries
# Should be called twice: once for MessageFile, once for UploadFile
assert mock_session.scalars.call_count == 2
mock_get_url.assert_called_once_with(upload_file_id=str(mock_upload_file.id))
def test_message_end_with_remote_url(self, mock_pipeline, mock_message_file_remote):
"""Test that files array is populated correctly for REMOTE_URL transfer method."""
# Arrange
mock_message_file_remote.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_remote]
mock_session.scalars.return_value = mock_scalars_result
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert file_dict["related_id"] == mock_message_file_remote.id
assert file_dict["filename"] == "image.jpg"
assert file_dict["url"] == "https://example.com/image.jpg"
assert file_dict["extension"] == ".jpg"
assert file_dict["type"] == "image"
assert file_dict["transfer_method"] == FileTransferMethod.REMOTE_URL.value
assert file_dict["remote_url"] == "https://example.com/image.jpg"
assert file_dict["upload_file_id"] == mock_message_file_remote.id
# Verify only one query for message_files is made
mock_session.scalars.assert_called_once()
def test_message_end_with_tool_file_http(self, mock_pipeline, mock_message_file_tool):
"""Test that files array is populated correctly for TOOL_FILE with HTTP URL."""
# Arrange
mock_message_file_tool.message_id = mock_pipeline._message_id
mock_message_file_tool.url = "https://example.com/tool_file.png"
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_tool]
mock_session.scalars.return_value = mock_scalars_result
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert file_dict["url"] == "https://example.com/tool_file.png"
assert file_dict["filename"] == "tool_file.png"
assert file_dict["extension"] == ".png"
assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value
def test_message_end_with_tool_file_local(self, mock_pipeline, mock_message_file_tool):
"""Test that files array is populated correctly for TOOL_FILE with local path."""
# Arrange
mock_message_file_tool.message_id = mock_pipeline._message_id
mock_message_file_tool.url = "tool_file_123.png"
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_tool]
mock_session.scalars.return_value = mock_scalars_result
mock_sign_tool.return_value = "https://example.com/signed-tool-file.png?signature=xyz"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert "https://example.com/signed-tool-file.png" in file_dict["url"]
assert file_dict["filename"] == "tool_file_123.png"
assert file_dict["extension"] == ".png"
assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value
# Verify tool file signing was called
mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_123", extension=".png")
def test_message_end_with_tool_file_long_extension(self, mock_pipeline, mock_message_file_tool):
"""Test that TOOL_FILE extensions longer than MAX_TOOL_FILE_EXTENSION_LENGTH fall back to .bin."""
mock_message_file_tool.message_id = mock_pipeline._message_id
mock_message_file_tool.url = "tool_file_abc.verylongextension"
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_tool]
mock_session.scalars.return_value = mock_scalars_result
mock_sign_tool.return_value = "https://example.com/signed.bin"
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
assert result.files is not None
file_dict = result.files[0]
assert file_dict["extension"] == ".bin"
mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_abc", extension=".bin")
def test_message_end_with_multiple_files(
self, mock_pipeline, mock_message_file_local, mock_message_file_remote, mock_upload_file
):
"""Test that files array contains all MessageFile records when multiple exist."""
# Arrange
mock_message_file_local.message_id = mock_pipeline._message_id
mock_message_file_remote.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
# First query: MessageFile
mock_message_files_result = Mock()
mock_message_files_result.all.return_value = [mock_message_file_local, mock_message_file_remote]
# Second query: UploadFile (batch query to avoid N+1)
mock_upload_files_result = Mock()
mock_upload_files_result.all.return_value = [mock_upload_file]
# Setup scalars to return different results for different queries
call_count = [0]  # Use list to allow modification in nested function
def scalars_side_effect(query):
call_count[0] += 1
# First call is for MessageFile, second call is for UploadFile
if call_count[0] == 1:
return mock_message_files_result
else:
return mock_upload_files_result
mock_session.scalars.side_effect = scalars_side_effect
mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 2
# Verify both files are present
file_ids = [f["related_id"] for f in result.files]
assert mock_message_file_local.id in file_ids
assert mock_message_file_remote.id in file_ids
def test_message_end_with_local_file_no_upload_file(self, mock_pipeline, mock_message_file_local):
"""Test fallback when UploadFile is not found for LOCAL_FILE."""
# Arrange
mock_message_file_local.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
# First query: MessageFile
mock_message_files_result = Mock()
mock_message_files_result.all.return_value = [mock_message_file_local]
# Second query: UploadFile (batch query) - returns empty list (not found)
mock_upload_files_result = Mock()
mock_upload_files_result.all.return_value = []  # UploadFile not found
# Setup scalars to return different results for different queries
call_count = [0]  # Use list to allow modification in nested function
def scalars_side_effect(query):
call_count[0] += 1
# First call is for MessageFile, second call is for UploadFile
if call_count[0] == 1:
return mock_message_files_result
else:
return mock_upload_files_result
mock_session.scalars.side_effect = scalars_side_effect
mock_get_url.return_value = "https://example.com/fallback-url?signature=def456"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert "https://example.com/fallback-url" in file_dict["url"]
# Verify fallback URL was generated using upload_file_id from message_file
mock_get_url.assert_called_with(upload_file_id=str(mock_message_file_local.upload_file_id))

View File

@@ -0,0 +1,84 @@
from datetime import datetime
from unittest.mock import MagicMock
from uuid import uuid4
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository:
    """Build a repository whose session factory always yields *session*.

    A throwaway in-memory SQLite engine satisfies the constructor; afterwards
    the real factory is swapped for a mock context manager that hands back the
    supplied mocked session.
    """
    sqlite_engine = create_engine("sqlite:///:memory:")
    factory = sessionmaker(bind=sqlite_engine, expire_on_commit=False)
    account = MagicMock(spec=Account)
    account.id = str(uuid4())
    account.current_tenant_id = str(uuid4())
    repo = SQLAlchemyWorkflowExecutionRepository(
        session_factory=factory,
        user=account,
        app_id="app-id",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )
    # Replace the factory so `with repo._session_factory()` yields our mock.
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    repo._session_factory = MagicMock(return_value=ctx)
    return repo
def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution:
    """Create a minimal WorkflowExecution fixture with the given id and start time."""
    empty_graph = {"nodes": [], "edges": []}
    return WorkflowExecution.new(
        id_=execution_id,
        workflow_id="workflow-id",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0.0",
        graph=empty_graph,
        inputs={"query": "hello"},
        started_at=started_at,
    )
def test_save_uses_execution_started_at_when_record_does_not_exist():
    """A brand-new run's created_at is taken from the execution's started_at."""
    session = MagicMock()
    session.get.return_value = None  # no pre-existing WorkflowRun row
    repo = _build_repository_with_mocked_session(session)

    begun = datetime(2026, 1, 1, 12, 0, 0)
    repo.save(_build_execution(execution_id=str(uuid4()), started_at=begun))

    merged = session.merge.call_args.args[0]
    assert merged.created_at == begun
    session.commit.assert_called_once()
def test_save_preserves_existing_created_at_when_record_already_exists():
    """Re-saving an execution must not rewrite the stored created_at."""
    session = MagicMock()
    repo = _build_repository_with_mocked_session(session)

    run_id = str(uuid4())
    original_created_at = datetime(2026, 1, 1, 12, 0, 0)
    stored_run = WorkflowRun()
    stored_run.id = run_id
    stored_run.tenant_id = repo._tenant_id
    stored_run.created_at = original_created_at
    session.get.return_value = stored_run

    # Save with a *later* started_at; created_at must still win.
    repo.save(_build_execution(execution_id=run_id, started_at=datetime(2026, 1, 1, 12, 30, 0)))

    merged = session.merge.call_args.args[0]
    assert merged.created_at == original_created_at
    session.commit.assert_called_once()

View File

@@ -4,8 +4,10 @@ from unittest.mock import MagicMock, patch
import pytest
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
from dify_graph.variables.variables import StringVariable
class StubCoordinator:
@@ -278,3 +280,17 @@ class TestGraphRuntimeState:
assert restored_execution.started is True
assert new_stub.state == "configured"
def test_snapshot_restore_preserves_updated_conversation_variable(self):
    """dumps()/from_snapshot round-trip keeps the latest value of a mutated conversation variable."""
    selector = (CONVERSATION_VARIABLE_NODE_ID, "session_name")
    pool = VariablePool(
        conversation_variables=[StringVariable(name="session_name", value="before")],
    )
    # Overwrite the initial value before snapshotting.
    pool.add(selector, "after")
    state = GraphRuntimeState(variable_pool=pool, start_at=time())

    restored = GraphRuntimeState.from_snapshot(state.dumps())

    restored_value = restored.variable_pool.get(selector)
    assert restored_value is not None
    assert restored_value.value == "after"

Some files were not shown because too many files have changed in this diff. Show More