Compare commits

...

174 Commits

Author SHA1 Message Date
JzoNg
1712a2732a Merge branch 'main' into jzh 2026-04-13 17:22:27 +08:00
JzoNg
46bc76bae3 feat(web): add evaluation cell in workflow log 2026-04-13 17:22:04 +08:00
Yunlu Wen
ae898652b2 refactor: move vdb implementations to workspaces (#34900)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
2026-04-13 08:56:43 +00:00
Ke Wang
c34f67495c refactor(api): type WorkflowRun.to_dict with WorkflowRunDict TypedDict (#35047)
Co-authored-by: Ke Wang <ke@pika.art>
2026-04-13 08:30:28 +00:00
hj24
815c536e05 fix: optimize trigger long running read transactions (#35046) 2026-04-13 08:22:54 +00:00
wangxiaolei
fc64427ae1 fix: fix qdrant delete size is too large (#35042)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 07:59:06 +00:00
wdeveloper16
11c518478e test: migrate AudioService TTS message-ID lookup tests to Testcontainers integration tests (#34992)
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-13 06:26:43 +00:00
volcano303
e823635ce1 test: migrate app_dsl_service tests to testcontainers (#34429)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 06:25:28 +00:00
sxxtony
98e74c8fde refactor: migrate MessageAnnotation to TypeBase (#34807) 2026-04-13 06:22:43 +00:00
呆萌闷油瓶
29bfa33d59 feat: support ttft report to langfuse (#33344)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-04-13 06:21:58 +00:00
JzoNg
8c6dda125f Merge branch 'main' into jzh 2026-04-13 14:13:32 +08:00
JzoNg
f6047aafe8 feat(web): metric descriptions 2026-04-13 14:13:04 +08:00
LincolnBurrows2017
3ead0beeb1 fix: correct typo submision to submission (#33435)
Co-authored-by: LincolnBurrows2017 <lincoln@example.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-04-13 05:59:21 +00:00
corevibe555
2108c44c8b refactor(api): consolidate duplicate RerankingModelConfig and WeightedScoreConfig definitions (#34747)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 05:53:45 +00:00
Statxc
b0079e55b4 refactor(api): type WorkflowAppLog.to_dict with WorkflowAppLogDict TypedDict (#34682) 2026-04-13 05:47:44 +00:00
sxxtony
d9f54f8bd7 refactor: migrate WorkflowPause and WorkflowPauseReason to TypeBase (#34688)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 05:46:52 +00:00
JzoNg
dce5715982 Merge branch 'main' into jzh 2026-04-13 13:35:13 +08:00
Jake Armstrong
5a446f8200 refactor(api): deduplicate dataset controller schemas into controller_schemas.py (#34756) 2026-04-13 05:33:20 +00:00
dev-miro26
f4d5e2f43d refactor(api): improve type safety in MCPToolManageService.execute_auth_actions (#34824) 2026-04-13 05:29:10 +00:00
Jake Armstrong
9121f24181 refactor(api): deduplicate TextToAudioPayload and MessageListQuery into controller_schemas.py (#34757) 2026-04-13 05:27:35 +00:00
sxxtony
7dd507af04 refactor: migrate SegmentAttachmentBinding to TypeBase (#34810) 2026-04-13 05:22:43 +00:00
NVIDIAN
3b9aad2ba7 refactor: replace inline api.model response schemas with register_schema_models in activate (#34929)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 05:21:46 +00:00
sxxtony
ea9f74b581 refactor: migrate RecommendedApp to TypeBase (#34808) 2026-04-13 05:19:49 +00:00
NVIDIAN
e37aaa482d refactor: migrate apikey from marshal_with/api.model to Pydantic BaseModel (#34932)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 05:18:42 +00:00
NVIDIAN
a3170f744c refactor: migrate app site from marshal_with/api.model to Pydantic BaseModel (#34933)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-13 05:18:16 +00:00
NVIDIAN
ced3780787 refactor: migrate mcp_server from marshal_with/api.model to Pydantic BaseModel (#34935)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 05:13:53 +00:00
NVIDIAN
6faf26683c refactor: remove marshal_with and inline api.model from app_import (#34934)
Co-authored-by: ai-hpc <ai-hpc@users.noreply.github.com>
2026-04-13 05:13:15 +00:00
dependabot[bot]
8ac9cbf733 chore(deps-dev): bump mypy from 1.20.0 to 1.20.1 in /api in the dev group (#35039)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-13 05:12:23 +00:00
dependabot[bot]
098ed34469 chore(deps): bump weave from 0.52.17 to 0.52.36 in /api in the llm group (#35038)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-13 04:29:31 +00:00
Crazywoola
6cf4d1002f chore: refine .github configs for dependabot, PR template, and stale workflow (#35035)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-13 04:11:31 +00:00
HeYinKazune
a111d56ea3 refactor: use sessionmaker in workflow_tools_manage_service.py (#34896)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 03:47:29 +00:00
wdeveloper16
8436470fcb refactor: replace bare dict with TypedDicts in annotation_service (#34998) 2026-04-13 03:46:33 +00:00
wdeveloper16
17da0e4146 test: migrate BillingService permission-check tests to Testcontainers integration tests (#34993)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 03:44:14 +00:00
wdeveloper16
ea41e9ab4e test: implement Account/Tenant model integration tests to replace db-mocked unit tests (#34994)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 03:39:16 +00:00
dependabot[bot]
5770b5feef chore(deps): bump the opentelemetry group across 1 directory with 16 updates (#35028)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-13 03:28:46 +00:00
Ygor Leal
b5259a3a85 refactor(api): enable reportUntypedFunctionDecorator in pyright config (#26412) (#35031) 2026-04-13 03:28:23 +00:00
XHamzaX
596559efc9 fix(rag): include is_summary and original_chunk_id in default vector projection (#34950)
Co-authored-by: VFootball Dev <vfootball@example.com>
2026-04-13 03:11:08 +00:00
JzoNg
ea910b8e7d Merge branch 'main' into jzh 2026-04-13 10:29:55 +08:00
JzoNg
e51af66d95 feat(web): support creator filtering in apps & snippets 2026-04-12 12:19:17 +08:00
JzoNg
f93b287949 feat(web): billing for evaluation & snippets 2026-04-12 11:09:15 +08:00
JzoNg
627fbd2e86 feat(web): rag-pipeline evaluation configuration 2026-04-12 10:36:20 +08:00
JzoNg
e4c056a57a Merge branch 'main' into jzh 2026-04-12 10:16:13 +08:00
JzoNg
23291398ec feat(web): switch warning dialog 2026-04-10 18:13:13 +08:00
JzoNg
79fc352a5a feat(web): evaluation run detail 2026-04-10 17:48:28 +08:00
JzoNg
8b6b3cddea refactor(web): rag-pipeline evaluation 2026-04-10 17:31:10 +08:00
JzoNg
d1ca468c1e Merge branch 'main' into jzh 2026-04-10 17:12:57 +08:00
JzoNg
ce28ad771c feat(web): test history of rag-pipeline evaluation 2026-04-10 17:12:27 +08:00
JzoNg
ba951b01de feat(web): template download 2026-04-10 17:05:14 +08:00
JzoNg
670ab16ea1 feat(web): template download & upload & run in rag-pipeline 2026-04-10 16:51:06 +08:00
JzoNg
4680535ecd Merge branch 'main' into jzh 2026-04-10 13:49:26 +08:00
JzoNg
f96e63460e feat(web): save configuration 2026-04-10 13:49:05 +08:00
JzoNg
2df79c0404 refactor: input fields 2026-04-10 12:04:51 +08:00
JzoNg
acef9630d5 feat(web): input fields display 2026-04-10 11:39:35 +08:00
JzoNg
12c3b2e0cd feat(web): start run 2026-04-10 11:26:12 +08:00
JzoNg
577707ae50 refactor(web): input fields 2026-04-10 11:10:24 +08:00
JzoNg
03325e9750 feat(web): run history of batch test 2026-04-10 10:58:40 +08:00
JzoNg
a7ef8f9c12 Merge branch 'main' into jzh 2026-04-10 10:36:45 +08:00
JzoNg
40284d9f95 refactor(web): batch test of evaluation 2026-04-09 21:19:51 +08:00
JzoNg
5efe8b8bd7 feat(web): add condition 2026-04-09 21:08:45 +08:00
JzoNg
8dc6d736ee refactor(web): store of evaluation 2026-04-09 20:32:53 +08:00
JzoNg
5316372772 feat(web): judgement condition 2026-04-09 20:18:25 +08:00
JzoNg
4d1499ef75 refactor(web): refactor condition group 2026-04-09 19:46:18 +08:00
JzoNg
0438285277 Merge branch 'main' into jzh 2026-04-09 19:24:10 +08:00
JzoNg
4879ea5cd5 feat(web): support variable selecting in variable mapping 2026-04-09 19:23:22 +08:00
JzoNg
2a1761ac06 feat(web): add output 2026-04-09 18:16:30 +08:00
JzoNg
c29245c1cb feat(web): only one evaluation workflow can be added 2026-04-09 17:43:34 +08:00
JzoNg
5069694bba refactor(web): remove unused metric property 2026-04-09 17:30:10 +08:00
JzoNg
d1a80a85c0 refactor(web): evaluation configure schema update 2026-04-09 17:17:15 +08:00
JzoNg
5c93d74dec Merge branch 'main' into jzh 2026-04-09 15:36:00 +08:00
JzoNg
e52dbd49be feat(web): dataset evaluation configure 2026-04-09 15:34:59 +08:00
JzoNg
ccc8a5f278 refactor(web): dataset evaluation 2026-04-09 14:56:22 +08:00
JzoNg
cfb5b9dfea feat(web): dataset evaluation configure fetch 2026-04-09 14:21:01 +08:00
JzoNg
73d95245f8 feat(web): dataset evaluation layout 2026-04-09 13:44:29 +08:00
JzoNg
fb91984fcb feat(web): add evaluation navigation for rag-pipeline 2026-04-09 13:26:43 +08:00
JzoNg
29cb1fa12e Merge branch 'main' into jzh 2026-04-09 13:15:20 +08:00
JzoNg
78240ed199 Merge branch 'main' into jzh 2026-04-09 09:07:12 +08:00
JzoNg
8f8707fd77 Merge branch 'main' into jzh 2026-04-07 16:57:37 +08:00
JzoNg
ed3db06154 feat(web): restrictions of evaluation workflow available nodes 2026-04-07 16:12:25 +08:00
JzoNg
7c05a68876 Merge branch 'main' into jzh 2026-04-07 14:41:42 +08:00
JzoNg
6cfc0dd8e1 Merge branch 'main' into jzh 2026-04-07 12:52:13 +08:00
JzoNg
81baeae5c4 fix(web): evaluation workflow switch 2026-04-03 18:22:44 +08:00
JzoNg
a3010bdc0b Merge branch 'main' into jzh 2026-04-03 18:05:54 +08:00
JzoNg
8133e550ed chore: fix pre-hook of web 2026-04-03 16:21:32 +08:00
JzoNg
2bb0eab636 chore(web): mapping row refactor 2026-04-03 16:10:41 +08:00
JzoNg
5311b5d00d feat(web): available evaluation workflow selector 2026-04-03 16:06:33 +08:00
JzoNg
9b02ccdd12 Merge branch 'main' into jzh 2026-04-03 15:15:11 +08:00
JzoNg
231783eebe chore(web): fix lint 2026-04-03 15:13:52 +08:00
JzoNg
756606f478 feat(web): hide card view in evaluation 2026-04-03 14:39:41 +08:00
JzoNg
6651c1c5da feat(web): workflow switch 2026-04-03 14:22:50 +08:00
JzoNg
61e257b2a8 feat(web): app switch api 2026-04-03 13:56:00 +08:00
JzoNg
3ac4caf735 Merge branch 'main' into jzh 2026-04-03 11:28:22 +08:00
JzoNg
268ae1751d Merge branch 'main' into jzh 2026-04-01 09:26:13 +08:00
JzoNg
015cbf850b Merge branch 'main' into jzh 2026-03-31 18:08:24 +08:00
JzoNg
873e13c2fb feat(web): support select node in metric card 2026-03-31 18:07:52 +08:00
JzoNg
688bf7e7a1 feat(web): metric card style 2026-03-31 17:43:56 +08:00
JzoNg
a6ffff3b39 fix(web): fix style of metric selector 2026-03-31 17:22:07 +08:00
JzoNg
023fc55bd5 fix(web): empty state of metric 2026-03-31 17:11:44 +08:00
JzoNg
351b909a53 feat(web): metric card 2026-03-31 17:00:37 +08:00
JzoNg
6bec4f65c9 refactor(web): metric section refactor 2026-03-31 16:28:48 +08:00
JzoNg
74f87ce152 Merge branch 'main' into jzh 2026-03-31 16:13:04 +08:00
JzoNg
92c472ccc7 Merge branch 'main' into jzh 2026-03-30 15:40:23 +08:00
JzoNg
b92b8becd1 feat(web): metric selector 2026-03-30 15:39:52 +08:00
JzoNg
23d0d6a65d chore(web): i18n of metrics 2026-03-30 14:20:43 +08:00
JzoNg
1660067d6e feat(web): judgement model selector 2026-03-30 14:03:37 +08:00
JzoNg
0642475b85 Merge branch 'main' into jzh 2026-03-30 13:30:10 +08:00
JzoNg
8cb634c9bc feat(web): evaluation layout 2026-03-30 11:27:06 +08:00
JzoNg
768b41c3cf Merge branch 'main' into jzh 2026-03-30 11:07:42 +08:00
JzoNg
ca88516d54 refactor(web): refactor evaluation page 2026-03-30 11:06:41 +08:00
JzoNg
871a2a149f refactor(web): split snippet index 2026-03-30 10:32:59 +08:00
JzoNg
60e381eff0 Merge branch 'main' into jzh 2026-03-30 09:48:58 +08:00
JzoNg
768b3eb6f9 feat(web): test run of snippet 2026-03-29 20:55:11 +08:00
JzoNg
2f88da4a6d feat(web): add variable inspect for snippet 2026-03-29 20:23:24 +08:00
JzoNg
a8cdf6964c feat(web): test run button 2026-03-29 20:02:59 +08:00
JzoNg
985c3db4fd feat(web): snippet input field panel layout 2026-03-29 18:02:27 +08:00
JzoNg
9636472db7 refactor(web): snippet main 2026-03-29 17:50:30 +08:00
JzoNg
0ad268aa7d feat(web): snippet publish 2026-03-29 17:29:37 +08:00
JzoNg
a4ea33167d feat(web): block selector in snippet 2026-03-29 17:01:32 +08:00
JzoNg
0f13aabea8 feat(web): input fields in snippet 2026-03-29 16:31:38 +08:00
JzoNg
1e76ef5ccb chore(web): ignore system vars & conversation vars in rag-pipeline and snippet 2026-03-29 15:56:24 +08:00
JzoNg
e6e3229d17 feat(web): input field button style 2026-03-29 15:45:05 +08:00
JzoNg
dccf8e723a feat(web): snippet version panel 2026-03-29 15:26:59 +08:00
JzoNg
c41ba7d627 feat(web): snippet header in graph 2026-03-29 15:02:34 +08:00
JzoNg
a6e9316de3 Merge branch 'main' into jzh 2026-03-29 14:07:49 +08:00
JzoNg
559d326cbd chore(web): mock data of snippet 2026-03-27 17:24:01 +08:00
JzoNg
abedf2506f Merge branch 'main' into jzh 2026-03-27 17:01:27 +08:00
JzoNg
d01428b5bc feat(web): snippet graph draft sync 2026-03-27 16:02:47 +08:00
JzoNg
0de1f17e5c Merge branch 'main' into jzh 2026-03-27 15:23:49 +08:00
JzoNg
17d07a5a43 feat(web): init snippet graph 2026-03-27 15:23:03 +08:00
JzoNg
3bdbea99a3 Merge branch 'main' into jzh 2026-03-27 14:04:10 +08:00
JzoNg
b7683aedb1 Merge branch 'main' into jzh 2026-03-26 21:38:48 +08:00
JzoNg
515036e758 test(web): add tests for snippets 2026-03-26 21:38:22 +08:00
JzoNg
22b382527f feat(web): add snippet to workflow 2026-03-26 21:26:29 +08:00
JzoNg
2cfe4b5b86 feat(web): snippet graph data fetching 2026-03-26 21:11:09 +08:00
JzoNg
6876c8041c feat(web): snippet list data fetching in block selector 2026-03-26 20:58:42 +08:00
JzoNg
7de45584ce refactor: snippets list 2026-03-26 20:41:51 +08:00
JzoNg
5572d7c7e8 Merge branch 'main' into jzh 2026-03-26 20:10:47 +08:00
JzoNg
db0a2fe52e Merge branch 'main' into jzh 2026-03-26 16:29:44 +08:00
JzoNg
f0ae8d6167 fix(web): unused imports caused by merge 2026-03-26 16:28:56 +08:00
JzoNg
2514e181ba Merge branch 'main' into jzh 2026-03-26 16:16:10 +08:00
JzoNg
be2e6e9a14 Merge branch 'main' into jzh 2026-03-26 14:23:29 +08:00
JzoNg
875e2eac1b Merge branch 'main' into jzh 2026-03-26 08:38:57 +08:00
JzoNg
c3c73ceb1f Merge branch 'main' into jzh 2026-03-25 23:02:18 +08:00
JzoNg
6318bf0a2a feat(web): create snippet from workflow 2026-03-25 22:57:48 +08:00
JzoNg
5e1f252046 feat(web): selection context menu style update 2026-03-25 22:36:27 +08:00
JzoNg
df3b960505 fix(web): position of selection context menu in workflow graph 2026-03-25 22:02:50 +08:00
JzoNg
26bc108bf1 chore(web): tests for snippet info 2026-03-25 21:35:36 +08:00
JzoNg
a5cff32743 feat(web): snippet info operations 2026-03-25 21:29:06 +08:00
JzoNg
d418dd8eec Merge branch 'main' into jzh 2026-03-25 20:17:32 +08:00
JzoNg
61702fe346 Merge branch 'main' into jzh 2026-03-25 18:17:03 +08:00
JzoNg
43f0c780c3 Merge branch 'main' into jzh 2026-03-25 15:30:21 +08:00
JzoNg
30ebf2bfa9 Merge branch 'main' into jzh 2026-03-24 07:25:22 +08:00
JzoNg
7e3027b5f7 feat(web): snippet card usage info 2026-03-23 17:02:00 +08:00
JzoNg
b3acf83090 Merge branch 'main' into jzh 2026-03-23 16:46:26 +08:00
JzoNg
36c3d6e48a feat(web): snippet list fetching & display 2026-03-23 16:37:05 +08:00
JzoNg
f782ac6b3c feat(web): create snippets by DSL import 2026-03-23 14:55:36 +08:00
JzoNg
feef2dd1fa feat(web): add snippet creation dialog flow 2026-03-23 11:29:41 +08:00
JzoNg
a716d8789d refactor: extract snippet list components 2026-03-23 10:48:15 +08:00
JzoNg
6816f89189 Merge branch 'main' into jzh 2026-03-23 10:13:45 +08:00
JzoNg
bfcac64a9d Merge branch 'main' into jzh 2026-03-20 15:33:49 +08:00
JzoNg
664eb601a2 feat(web): add api of snippet workflows 2026-03-20 15:29:53 +08:00
JzoNg
8e5cc4e0aa feat(web): add evaluation api 2026-03-20 15:23:03 +08:00
JzoNg
9f28575903 feat(web): add snippets api 2026-03-20 15:11:33 +08:00
JzoNg
4b9a26a5e6 Merge branch 'main' into jzh 2026-03-20 14:01:34 +08:00
JzoNg
7b85adf1cc Merge branch 'main' into jzh 2026-03-20 10:46:45 +08:00
JzoNg
c964708ebe Merge branch 'main' into jzh 2026-03-18 18:07:20 +08:00
JzoNg
883eb498c0 Merge branch 'main' into jzh 2026-03-18 17:40:51 +08:00
JzoNg
4d3738d225 Merge branch 'main' into feat/evaluation-fe 2026-03-17 10:42:44 +08:00
JzoNg
dd0dee739d Merge branch 'main' into jzh 2026-03-16 15:43:20 +08:00
zxhlyh
4d19914fcb Merge branch 'main' into feat/evaluation-fe 2026-03-16 10:47:37 +08:00
zxhlyh
887c7710e9 feat: evaluation 2026-03-16 10:46:33 +08:00
zxhlyh
7a722773c7 feat: snippet canvas 2026-03-13 17:45:04 +08:00
zxhlyh
a763aff58b feat: snippets list 2026-03-13 16:12:42 +08:00
zxhlyh
c1011f4e5c feat: add to snippet 2026-03-13 14:29:59 +08:00
zxhlyh
f7afa103a5 feat: select snippets 2026-03-13 13:43:29 +08:00
489 changed files with 24305 additions and 4224 deletions

.github/dependabot.yml

@@ -1,106 +1,6 @@
 version: 2
 updates:
-  - package-ecosystem: "pip"
-    directory: "/api"
-    open-pull-requests-limit: 10
-    schedule:
-      interval: "weekly"
-    groups:
-      flask:
-        patterns:
-          - "flask"
-          - "flask-*"
-          - "werkzeug"
-          - "gunicorn"
-      google:
-        patterns:
-          - "google-*"
-          - "googleapis-*"
-      opentelemetry:
-        patterns:
-          - "opentelemetry-*"
-      pydantic:
-        patterns:
-          - "pydantic"
-          - "pydantic-*"
-      llm:
-        patterns:
-          - "langfuse"
-          - "langsmith"
-          - "litellm"
-          - "mlflow*"
-          - "opik"
-          - "weave*"
-          - "arize*"
-          - "tiktoken"
-          - "transformers"
-      database:
-        patterns:
-          - "sqlalchemy"
-          - "psycopg2*"
-          - "psycogreen"
-          - "redis*"
-          - "alembic*"
-      storage:
-        patterns:
-          - "boto3*"
-          - "botocore*"
-          - "azure-*"
-          - "bce-*"
-          - "cos-python-*"
-          - "esdk-obs-*"
-          - "google-cloud-storage"
-          - "opendal"
-          - "oss2"
-          - "supabase*"
-          - "tos*"
-      vdb:
-        patterns:
-          - "alibabacloud*"
-          - "chromadb"
-          - "clickhouse-*"
-          - "clickzetta-*"
-          - "couchbase"
-          - "elasticsearch"
-          - "opensearch-py"
-          - "oracledb"
-          - "pgvect*"
-          - "pymilvus"
-          - "pymochow"
-          - "pyobvector"
-          - "qdrant-client"
-          - "intersystems-*"
-          - "tablestore"
-          - "tcvectordb"
-          - "tidb-vector"
-          - "upstash-*"
-          - "volcengine-*"
-          - "weaviate-*"
-          - "xinference-*"
-          - "mo-vector"
-          - "mysql-connector-*"
-      dev:
-        patterns:
-          - "coverage"
-          - "dotenv-linter"
-          - "faker"
-          - "lxml-stubs"
-          - "basedpyright"
-          - "ruff"
-          - "pytest*"
-          - "types-*"
-          - "boto3-stubs"
-          - "hypothesis"
-          - "pandas-stubs"
-          - "scipy-stubs"
-          - "import-linter"
-          - "celery-types"
-          - "mypy*"
-          - "pyrefly"
-      python-packages:
-        patterns:
-          - "*"
   - package-ecosystem: "uv"
     directory: "/api"
     open-pull-requests-limit: 10


@@ -18,7 +18,7 @@
 ## Checklist
 - [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
-- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
-- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
-- [x] I've updated the documentation accordingly.
-- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
+- [ ] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
+- [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
+- [ ] I've updated the documentation accordingly.
+- [ ] I ran `make lint && make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods


@@ -92,6 +92,7 @@ jobs:
       vdb:
         - 'api/core/rag/datasource/**'
         - 'api/tests/integration_tests/vdb/**'
+        - 'api/providers/vdb/*/tests/**'
         - '.github/workflows/vdb-tests.yml'
         - '.github/workflows/expose_service_ports.sh'
         - 'docker/.env.example'


@@ -23,8 +23,8 @@ jobs:
           days-before-issue-stale: 15
           days-before-issue-close: 3
           repo-token: ${{ secrets.GITHUB_TOKEN }}
-          stale-issue-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
-          stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
+          stale-issue-message: "Closed due to inactivity. If you have any questions, you can reopen it."
+          stale-pr-message: "Closed due to inactivity. If you have any questions, you can reopen it."
           stale-issue-label: 'no-issue-activity'
           stale-pr-label: 'no-pr-activity'
-          any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement,cant-reproduce,help-wanted'
+          any-of-labels: '🌚 invalid,🙋‍♂️ question,wont-fix,no-issue-activity,no-pr-activity,💪 enhancement,🤔 cant-reproduce,🙏 help wanted'


@@ -89,7 +89,7 @@ jobs:
           cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
       # - name: Check VDB Ready (TiDB)
-      #   run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
+      #   run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
       - name: Test Vector Stores
         run: uv run --project api bash dev/pytest/pytest_vdb.sh


@@ -81,12 +81,12 @@ jobs:
           cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
       # - name: Check VDB Ready (TiDB)
-      #   run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
+      #   run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
       - name: Test Vector Stores
         run: |
           uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
-            api/tests/integration_tests/vdb/chroma \
-            api/tests/integration_tests/vdb/pgvector \
-            api/tests/integration_tests/vdb/qdrant \
-            api/tests/integration_tests/vdb/weaviate
+            api/providers/vdb/vdb-chroma/tests/integration_tests \
+            api/providers/vdb/vdb-pgvector/tests/integration_tests \
+            api/providers/vdb/vdb-qdrant/tests/integration_tests \
+            api/providers/vdb/vdb-weaviate/tests/integration_tests


@@ -77,7 +77,7 @@ if $web_modified; then
   fi
   cd ./web || exit 1
-  vp staged
+  pnpm exec vp staged
   if $web_ts_modified; then
     echo "Running TypeScript type-check:tsgo"


@@ -21,8 +21,9 @@ RUN apt-get update \
     # for building gmpy2
     libmpfr-dev libmpc-dev
-# Install Python dependencies
+# Install Python dependencies (workspace members under providers/vdb/)
 COPY pyproject.toml uv.lock ./
+COPY providers ./providers
 RUN uv sync --locked --no-dev
 # production stage


@@ -341,11 +341,10 @@ def add_qdrant_index(field: str):
         click.echo(click.style("No dataset collection bindings found.", fg="red"))
         return
     import qdrant_client
+    from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
     from qdrant_client.http.exceptions import UnexpectedResponse
     from qdrant_client.http.models import PayloadSchemaType
-    from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
     for binding in bindings:
         if dify_config.QDRANT_URL is None:
             raise ValueError("Qdrant URL is required.")
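For reference, the payload-index call that a command like add_qdrant_index builds up to looks roughly like this with the qdrant-client API. This is a minimal sketch: the URL, collection, and field names are placeholders, since the real command derives them from each dataset collection binding.

# Minimal sketch of adding a payload index with qdrant-client.
# The URL, collection, and field names below are hypothetical.
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType

client = QdrantClient(url="http://localhost:6333")  # assumed local instance
try:
    client.create_payload_index(
        collection_name="example_collection",    # placeholder
        field_name="group_id",                   # placeholder payload field
        field_schema=PayloadSchemaType.KEYWORD,  # index type for exact-match filters
    )
except UnexpectedResponse:
    # e.g. the collection does not exist; a CLI command would log and continue
    pass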


@@ -1,4 +1,3 @@
-from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
 from pydantic import Field
 from pydantic_settings import BaseSettings
@@ -42,17 +41,17 @@ class HologresConfig(BaseSettings):
         default="public",
     )
-    HOLOGRES_TOKENIZER: TokenizerType = Field(
+    HOLOGRES_TOKENIZER: str = Field(
         description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
         default="jieba",
     )
-    HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
+    HOLOGRES_DISTANCE_METHOD: str = Field(
         description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
         default="Cosine",
     )
-    HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
+    HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
         description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
         default="rabitq",
     )
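The practical effect of this hunk is that the settings module no longer imports holo_search_sdk at import time; the allowed values move into the field descriptions. A standalone sketch of the loosened style, trimmed to two fields:

# Sketch of the loosened config style: plain `str` fields with allowed values
# documented, instead of enum types imported from holo_search_sdk.
from pydantic import Field
from pydantic_settings import BaseSettings

class HologresConfigSketch(BaseSettings):
    HOLOGRES_TOKENIZER: str = Field(
        description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
        default="jieba",
    )
    HOLOGRES_DISTANCE_METHOD: str = Field(
        description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
        default="Cosine",
    )

config = HologresConfigSketch()       # values may be overridden via environment variables
print(config.HOLOGRES_TOKENIZER)      # "jieba" unless HOLOGRES_TOKENIZER is set in the env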


@@ -1,4 +1,5 @@
 from typing import Any, Literal
+from uuid import UUID
 from pydantic import BaseModel, Field, model_validator
@@ -23,9 +24,9 @@ class ConversationRenamePayload(BaseModel):
 class MessageListQuery(BaseModel):
-    conversation_id: UUIDStrOrEmpty
-    first_id: UUIDStrOrEmpty | None = None
-    limit: int = Field(default=20, ge=1, le=100)
+    conversation_id: UUIDStrOrEmpty = Field(description="Conversation UUID")
+    first_id: UUIDStrOrEmpty | None = Field(default=None, description="First message ID for pagination")
+    limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
 class MessageFeedbackPayload(BaseModel):
@@ -69,11 +70,35 @@ class WorkflowUpdatePayload(BaseModel):
     marked_comment: str | None = Field(default=None, max_length=100)
+# --- Dataset schemas ---
+DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
+class ChildChunkCreatePayload(BaseModel):
+    content: str
+class ChildChunkUpdatePayload(BaseModel):
+    content: str
+class DocumentBatchDownloadZipPayload(BaseModel):
+    """Request payload for bulk downloading documents as a zip archive."""
+    document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
+class MetadataUpdatePayload(BaseModel):
+    name: str
 # --- Audio schemas ---
 class TextToAudioPayload(BaseModel):
-    message_id: str | None = None
-    voice: str | None = None
-    text: str | None = None
-    streaming: bool | None = None
+    message_id: str | None = Field(default=None, description="Message ID")
+    voice: str | None = Field(default=None, description="Voice to use for TTS")
+    text: str | None = Field(default=None, description="Text to convert to audio")
+    streaming: bool | None = Field(default=None, description="Enable streaming response")
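A quick sketch of how such a payload model behaves under validation. `UUIDStrOrEmpty` is approximated as plain `str` here, since its real definition lives elsewhere in the codebase; the `description` arguments flow into the JSON schema that the Swagger registration consumes.

# Standalone sketch; UUIDStrOrEmpty is approximated as `str` for illustration.
from pydantic import BaseModel, Field, ValidationError

class MessageListQuerySketch(BaseModel):
    conversation_id: str = Field(description="Conversation UUID")
    first_id: str | None = Field(default=None, description="First message ID for pagination")
    limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")

query = MessageListQuerySketch.model_validate({"conversation_id": "abc", "limit": "50"})
print(query.limit)  # 50: string input is coerced to int

try:
    MessageListQuerySketch.model_validate({"conversation_id": "abc", "limit": 0})
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # "greater_than_equal": limit must be between 1 and 100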


@@ -1,12 +1,16 @@
from datetime import datetime
import flask_restx
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource
from flask_restx._http import HTTPStatus
from pydantic import field_validator
from sqlalchemy import delete, func, select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from extensions.ext_database import db
from libs.helper import TimestampField
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.enums import ApiTokenType
@@ -16,21 +20,31 @@ from services.api_token_service import ApiTokenCache
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = {
"id": fields.String,
"type": fields.String,
"token": fields.String,
"last_used_at": TimestampField,
"created_at": TimestampField,
}
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
api_key_list_model = console_ns.model(
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
)
class ApiKeyItem(ResponseModel):
id: str
type: str
token: str
last_used_at: int | None = None
created_at: int | None = None
@field_validator("last_used_at", "created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class ApiKeyList(ResponseModel):
data: list[ApiKeyItem]
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
def _get_resource(resource_id, tenant_id, resource_model):
@@ -54,7 +68,6 @@ class BaseApiKeyListResource(Resource):
token_prefix: str | None = None
max_keys = 10
@marshal_with(api_key_list_model)
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
@@ -66,9 +79,8 @@ class BaseApiKeyListResource(Resource):
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
)
).all()
return {"items": keys}
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
@marshal_with(api_key_item_model)
@edit_permission_required
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
@@ -100,7 +112,7 @@ class BaseApiKeyListResource(Resource):
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return api_token, 201
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
class BaseApiKeyResource(Resource):
@@ -147,7 +159,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_app_api_keys")
@console_ns.doc(description="Get all API keys for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(200, "Success", api_key_list_model)
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id): # type: ignore
"""Get all API keys for an app"""
return super().get(resource_id)
@@ -155,7 +167,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("create_app_api_key")
@console_ns.doc(description="Create a new API key for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(201, "API key created successfully", api_key_item_model)
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore
"""Create a new API key for an app"""
@@ -187,7 +199,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get all API keys for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(200, "Success", api_key_list_model)
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id): # type: ignore
"""Get all API keys for a dataset"""
return super().get(resource_id)
@@ -195,7 +207,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("create_dataset_api_key")
@console_ns.doc(description="Create a new API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(201, "API key created successfully", api_key_item_model)
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore
"""Create a new API key for a dataset"""


@@ -25,7 +25,13 @@ from fields.annotation_fields import (
)
from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService
from services.annotation_service import (
AppAnnotationService,
EnableAnnotationArgs,
UpdateAnnotationArgs,
UpdateAnnotationSettingArgs,
UpsertAnnotationArgs,
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -120,7 +126,12 @@ class AnnotationReplyActionApi(Resource):
args = AnnotationReplyPayload.model_validate(console_ns.payload)
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
enable_args: EnableAnnotationArgs = {
"score_threshold": args.score_threshold,
"embedding_provider_name": args.embedding_provider_name,
"embedding_model_name": args.embedding_model_name,
}
result = AppAnnotationService.enable_app_annotation(enable_args, app_id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@@ -161,7 +172,8 @@ class AppAnnotationSettingUpdateApi(Resource):
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args)
return result, 200
@@ -237,8 +249,16 @@ class AnnotationApi(Resource):
def post(self, app_id):
app_id = str(app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
upsert_args: UpsertAnnotationArgs = {}
if args.answer is not None:
upsert_args["answer"] = args.answer
if args.content is not None:
upsert_args["content"] = args.content
if args.message_id is not None:
upsert_args["message_id"] = args.message_id
if args.question is not None:
upsert_args["question"] = args.question
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@@ -315,9 +335,12 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id)
annotation_id = str(annotation_id)
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
update_args: UpdateAnnotationArgs = {}
if args.answer is not None:
update_args["answer"] = args.answer
if args.question is not None:
update_args["question"] = args.question
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
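The controllers above now build explicitly typed argument dicts instead of passing `args.model_dump()` straight through. The real TypedDicts live in services.annotation_service and are not shown in this diff; the sketch below illustrates the shape the call sites imply, with field types assumed:

# Hedged sketch: the real TypedDicts are defined in services.annotation_service;
# field types here are assumptions inferred from the controller code above.
from typing import TypedDict

class EnableAnnotationArgs(TypedDict):
    score_threshold: float
    embedding_provider_name: str
    embedding_model_name: str

class UpdateAnnotationArgs(TypedDict, total=False):
    # total=False matches the controllers building this dict key by key,
    # adding each field only when the payload value is not None
    question: str
    answer: str

enable_args: EnableAnnotationArgs = {
    "score_threshold": 0.8,                              # placeholder values
    "embedding_provider_name": "openai",
    "embedding_model_name": "text-embedding-3-small",
}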


@@ -1,7 +1,8 @@
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
@@ -10,35 +11,15 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import (
app_import_check_dependencies_fields,
app_import_fields,
leaked_dependency_fields,
)
from libs.login import current_account_with_tenant, login_required
from models.model import App
from services.app_dsl_service import AppDslService
from services.app_dsl_service import AppDslService, Import
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.dsl_entities import ImportStatus
from services.entities.dsl_entities import CheckDependenciesResult, ImportStatus
from services.feature_service import FeatureService
from .. import console_ns
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields)
app_import_model = console_ns.model("AppImport", app_import_fields)
# For nested models, need to replace nested dict with registered model
app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy()
app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
app_import_check_dependencies_model = console_ns.model(
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppImportPayload(BaseModel):
mode: str = Field(..., description="Import mode")
@@ -52,18 +33,18 @@ class AppImportPayload(BaseModel):
app_id: str | None = Field(None)
console_ns.schema_model(
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult)
@console_ns.route("/apps/imports")
class AppImportApi(Resource):
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
@console_ns.response(200, "Import completed", console_ns.models[Import.__name__])
@console_ns.response(202, "Import pending confirmation", console_ns.models[Import.__name__])
@console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@@ -104,10 +85,11 @@ class AppImportApi(Resource):
@console_ns.route("/apps/imports/<string:import_id>/confirm")
class AppImportConfirmApi(Resource):
@console_ns.response(200, "Import confirmed", console_ns.models[Import.__name__])
@console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_model)
@edit_permission_required
def post(self, import_id):
# Check user role first
@@ -128,11 +110,11 @@ class AppImportConfirmApi(Resource):
@console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
class AppImportCheckDependenciesApi(Resource):
@console_ns.response(200, "Dependencies checked", console_ns.models[CheckDependenciesResult.__name__])
@setup_required
@login_required
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_model)
@edit_permission_required
def get(self, app_model: App):
with sessionmaker(db.engine).begin() as session:
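`register_schema_models` itself is imported from controllers.common.schema and its body is not part of this diff. Judging from the `console_ns.schema_model(...)` calls it replaces, it plausibly reduces to a loop like this hypothetical reconstruction:

# Hypothetical reconstruction of the helper this refactor leans on; the real
# implementation lives in controllers.common.schema and may differ.
from flask_restx import Namespace
from pydantic import BaseModel

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

def register_schema_models(ns: Namespace, *models: type[BaseModel]) -> None:
    for model in models:
        # schema_model registers a raw JSON schema under the model's name,
        # so @ns.response(...) can reference ns.models[Model.__name__]
        ns.schema_model(
            model.__name__,
            model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
        )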


@@ -1,23 +1,27 @@
import json
from datetime import datetime
from typing import Any
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.app_fields import app_server_fields
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.enums import AppMCPServerStatus
from models.model import AppMCPServer
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register model for flask_restx to avoid dict type issues in Swagger
app_server_model = console_ns.model("AppServer", app_server_fields)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class MCPServerCreatePayload(BaseModel):
@@ -32,8 +36,33 @@ class MCPServerUpdatePayload(BaseModel):
status: str | None = Field(default=None, description="Server status")
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class AppMCPServerResponse(ResponseModel):
id: str
name: str
server_code: str
description: str
status: str
parameters: dict[str, Any] | list[Any] | str
created_at: int | None = None
updated_at: int | None = None
@field_validator("parameters", mode="before")
@classmethod
def _parse_json_string(cls, value: Any) -> Any:
if isinstance(value, str):
try:
return json.loads(value)
except (json.JSONDecodeError, TypeError):
return value
return value
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse)
@console_ns.route("/apps/<uuid:app_id>/server")
@@ -41,27 +70,27 @@ class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server")
@console_ns.doc(description="Get MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model)
@console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__])
@login_required
@account_initialization_required
@setup_required
@get_app_model
@marshal_with(app_server_model)
def get(self, app_model):
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
return server
if server is None:
return {}
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
@console_ns.doc("create_app_mcp_server")
@console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
@console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@get_app_model
@login_required
@setup_required
@marshal_with(app_server_model)
@edit_permission_required
def post(self, app_model):
_, current_tenant_id = current_account_with_tenant()
@@ -82,20 +111,19 @@ class AppMCPServerController(Resource):
)
db.session.add(server)
db.session.commit()
return server
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
@console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
@console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@get_app_model
@login_required
@setup_required
@account_initialization_required
@marshal_with(app_server_model)
@edit_permission_required
def put(self, app_model):
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
@@ -118,7 +146,7 @@ class AppMCPServerController(Resource):
except ValueError:
raise ValueError("Invalid status")
db.session.commit()
return server
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
@console_ns.route("/apps/<uuid:server_id>/server/refresh")
@@ -126,13 +154,12 @@ class AppMCPServerRefreshController(Resource):
@console_ns.doc("refresh_app_mcp_server")
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
@console_ns.doc(params={"server_id": "Server ID"})
@console_ns.response(200, "MCP server refreshed successfully", app_server_model)
@console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_server_model)
@edit_permission_required
def get(self, server_id):
_, current_tenant_id = current_account_with_tenant()
@@ -145,4 +172,4 @@ class AppMCPServerRefreshController(Resource):
raise NotFound()
server.server_code = AppMCPServer.generate_server_code(16)
db.session.commit()
return server
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
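`parameters` is stored on the ORM row as a JSON string, so the before-validator has to tolerate both raw strings and already-parsed values. A standalone sketch of just that validator:

# Standalone sketch of the `parameters` before-validator shown above.
import json
from typing import Any
from pydantic import BaseModel, field_validator

class ParamsHolder(BaseModel):
    parameters: dict[str, Any] | list[Any] | str

    @field_validator("parameters", mode="before")
    @classmethod
    def _parse_json_string(cls, value: Any) -> Any:
        if isinstance(value, str):
            try:
                return json.loads(value)
            except (json.JSONDecodeError, TypeError):
                return value  # keep the raw string if it is not valid JSON
        return value

print(ParamsHolder.model_validate({"parameters": '{"a": 1}'}).parameters)   # {'a': 1}
print(ParamsHolder.model_validate({"parameters": "not json"}).parameters)  # 'not json'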


@@ -1,11 +1,12 @@
from typing import Literal
from flask_restx import Resource, marshal_with
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@@ -15,13 +16,11 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_site_fields
from fields.base import ResponseModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import Site
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppSiteUpdatePayload(BaseModel):
title: str | None = Field(default=None)
@@ -49,13 +48,26 @@ class AppSiteUpdatePayload(BaseModel):
return supported_language(value)
console_ns.schema_model(
AppSiteUpdatePayload.__name__,
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
class AppSiteResponse(ResponseModel):
app_id: str
access_token: str | None = Field(default=None, validation_alias="code")
code: str | None = None
title: str
icon: str | None = None
icon_background: str | None = None
description: str | None = None
default_language: str
customize_domain: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
customize_token_strategy: str
prompt_public: bool
show_workflow_steps: bool
use_icon_as_answer_icon: bool
# Register model for flask_restx to avoid dict type issues in Swagger
app_site_model = console_ns.model("AppSite", app_site_fields)
register_schema_models(console_ns, AppSiteUpdatePayload, AppSiteResponse)
@console_ns.route("/apps/<uuid:app_id>/site")
@@ -64,7 +76,7 @@ class AppSite(Resource):
@console_ns.doc(description="Update application site configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
@console_ns.response(200, "Site configuration updated successfully", console_ns.models[AppSiteResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "App not found")
@setup_required
@@ -72,7 +84,6 @@ class AppSite(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_model)
def post(self, app_model):
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
@@ -106,7 +117,7 @@ class AppSite(Resource):
site.updated_at = naive_utc_now()
db.session.commit()
return site
return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
@@ -114,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
@console_ns.doc("reset_app_site_access_token")
@console_ns.doc(description="Reset access token for application site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Access token reset successfully", app_site_model)
@console_ns.response(200, "Access token reset successfully", console_ns.models[AppSiteResponse.__name__])
@console_ns.response(403, "Insufficient permissions (admin/owner required)")
@console_ns.response(404, "App or site not found")
@setup_required
@@ -122,7 +133,6 @@ class AppSiteAccessTokenReset(Resource):
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_model)
def post(self, app_model):
current_user, _ = current_account_with_tenant()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
@@ -135,4 +145,4 @@ class AppSiteAccessTokenReset(Resource):
site.updated_at = naive_utc_now()
db.session.commit()
return site
return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")
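One detail worth noting in `AppSiteResponse`: `access_token` is populated from the source attribute `code` via `validation_alias`. A minimal sketch of that trick, with a stand-in object in place of the Site ORM model:

# Minimal sketch of the validation_alias trick used by AppSiteResponse:
# the response field `access_token` is filled from the source attribute `code`.
from pydantic import BaseModel, Field

class SiteRow:  # stand-in for the ORM object
    code = "abc123"
    title = "My App"

class SiteResponseSketch(BaseModel):
    access_token: str | None = Field(default=None, validation_alias="code")
    code: str | None = None
    title: str

resp = SiteResponseSketch.model_validate(SiteRow(), from_attributes=True)
print(resp.access_token)  # "abc123", sourced from SiteRow.code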


@@ -1,8 +1,9 @@
from flask import request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
@@ -11,8 +12,6 @@ from libs.helper import EmailStr, timezone
from models import AccountStatus
from services.account_service import RegisterService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ActivateCheckQuery(BaseModel):
workspace_id: str | None = Field(default=None)
@@ -39,8 +38,16 @@ class ActivatePayload(BaseModel):
return timezone(value)
for model in (ActivateCheckQuery, ActivatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class ActivationCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether token is valid")
data: dict | None = Field(default=None, description="Activation data if valid")
class ActivationResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)
@console_ns.route("/activate/check")
@@ -51,13 +58,7 @@ class ActivateCheckApi(Resource):
@console_ns.response(
200,
"Success",
console_ns.model(
"ActivationCheckResponse",
{
"is_valid": fields.Boolean(description="Whether token is valid"),
"data": fields.Raw(description="Activation data if valid"),
},
),
console_ns.models[ActivationCheckResponse.__name__],
)
def get(self):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@@ -95,12 +96,7 @@ class ActivateApi(Resource):
@console_ns.response(
200,
"Account activated successfully",
console_ns.model(
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
},
),
console_ns.models[ActivationResponse.__name__],
)
@console_ns.response(400, "Already activated or invalid token")
def post(self):


@@ -11,10 +11,7 @@ import services
from configs import dify_config
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import (
api_key_item_model,
api_key_list_model,
)
from controllers.console.apikey import ApiKeyItem, ApiKeyList
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import (
@@ -785,23 +782,23 @@ class DatasetApiKeyApi(Resource):
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get dataset API keys")
@console_ns.response(200, "API keys retrieved successfully", api_key_list_model)
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_list_model)
def get(self):
_, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars(
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
).all()
return {"items": keys}
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
@console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@marshal_with(api_key_item_model)
def post(self):
_, current_tenant_id = current_account_with_tenant()
@@ -828,7 +825,7 @@ class DatasetApiKeyApi(Resource):
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return api_token, 200
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")


@@ -4,7 +4,6 @@ from argparse import ArgumentTypeError
from collections.abc import Sequence
from contextlib import ExitStack
from typing import Any, Literal, cast
from uuid import UUID
import sqlalchemy as sa
from flask import request, send_file
@@ -16,6 +15,7 @@ from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from core.errors.error import (
@@ -71,9 +71,6 @@ from ..wraps import (
logger = logging.getLogger(__name__)
# NOTE: Keep constants near the top of the module for discoverability.
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_model = get_or_create_model("Dataset", dataset_fields)
@@ -110,12 +107,6 @@ class GenerateSummaryPayload(BaseModel):
document_list: list[str]
class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading documents as a zip archive."""
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")


@@ -10,6 +10,7 @@ from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
@@ -82,14 +83,6 @@ class BatchImportPayload(BaseModel):
upload_file_id: str
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkUpdatePayload(BaseModel):
content: str
class ChildChunkBatchUpdatePayload(BaseModel):
chunks: list[ChildChunkUpdateArgs]


@@ -1,9 +1,9 @@
from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
@@ -18,11 +18,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.metadata_service import MetadataService
class MetadataUpdatePayload(BaseModel):
name: str
register_schema_models(
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
)

View File

@@ -94,10 +94,9 @@ def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
def plugin_data[**P, R](
view: Callable[P, R] | None = None,
*,
payload_type: type[BaseModel],
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
@@ -116,7 +115,4 @@ def plugin_data[**P, R](
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
return decorator
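With the bare-call overload removed, `plugin_data` is now always invoked with parentheses. A minimal, hypothetical usage sketch (the payload model and view are illustrative, and the decorator's import is omitted because its module path is not shown in this diff):

```python
# Hedged sketch: `plugin_data` must now be applied as `@plugin_data(payload_type=...)`;
# the previous `@plugin_data` bare form is no longer supported.
from pydantic import BaseModel


class InstallPayload(BaseModel):  # hypothetical payload model
    plugin_id: str


@plugin_data(payload_type=InstallPayload)
def install_plugin(payload: InstallPayload):
    ...
```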

View File

@@ -12,7 +12,12 @@ from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import Annotation, AnnotationList
from models.model import App
from services.annotation_service import AppAnnotationService
from services.annotation_service import (
AppAnnotationService,
EnableAnnotationArgs,
InsertAnnotationArgs,
UpdateAnnotationArgs,
)
class AnnotationCreatePayload(BaseModel):
@@ -46,10 +51,15 @@ class AnnotationReplyActionApi(Resource):
@validate_app_token
def post(self, app_model: App, action: Literal["enable", "disable"]):
"""Enable or disable annotation reply feature."""
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
payload = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {})
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
enable_args: EnableAnnotationArgs = {
"score_threshold": payload.score_threshold,
"embedding_provider_name": payload.embedding_provider_name,
"embedding_model_name": payload.embedding_model_name,
}
result = AppAnnotationService.enable_app_annotation(enable_args, app_model.id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
return result, 200
@@ -135,8 +145,9 @@ class AnnotationListApi(Resource):
@validate_app_token
def post(self, app_model: App):
"""Create a new annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
insert_args: InsertAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.insert_app_annotation_directly(insert_args, app_model.id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json"), HTTPStatus.CREATED
@@ -164,8 +175,9 @@ class AnnotationUpdateDeleteApi(Resource):
@edit_permission_required
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")
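The pattern above replaces an untyped `model_dump()` dict with explicitly constructed TypedDicts. A self-contained sketch of why that helps (names are illustrative, mirroring `InsertAnnotationArgs`): a type checker can now verify the keys and value types at the call site.

```python
# Hedged sketch: a TypedDict makes the argument shape checkable, unlike a
# plain dict produced by model_dump().
from typing import TypedDict


class InsertArgs(TypedDict):
    question: str
    answer: str


def build_args(question: str, answer: str) -> InsertArgs:
    args: InsertArgs = {"question": question, "answer": answer}
    # args = {"question": question, "anwser": answer}  # typo -> type-check error
    return args
```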

View File

@@ -3,10 +3,10 @@ import logging
from flask import request
from flask_restx import Resource
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.controller_schemas import TextToAudioPayload
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
@@ -86,13 +86,6 @@ class AudioApi(Resource):
raise InternalServerError()
class TextToAudioPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
voice: str | None = Field(default=None, description="Voice to use for TTS")
text: str | None = Field(default=None, description="Text to convert to audio")
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(service_api_ns, TextToAudioPayload)

View File

@@ -10,6 +10,7 @@ from sqlalchemy import desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
@@ -100,15 +101,6 @@ class DocumentListQuery(BaseModel):
status: str | None = Field(default=None, description="Document status filter")
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading uploaded documents as a ZIP archive."""
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
register_enum_models(service_api_ns, RetrievalMethod)
register_schema_models(

View File

@@ -2,9 +2,9 @@ from typing import Literal
from flask_login import current_user
from flask_restx import marshal
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
@@ -18,11 +18,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.metadata_service import MetadataService
class MetadataUpdatePayload(BaseModel):
name: str
register_schema_model(service_api_ns, MetadataUpdatePayload)
register_schema_models(
service_api_ns,

View File

@@ -8,6 +8,7 @@ from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
@@ -69,20 +70,12 @@ class SegmentUpdatePayload(BaseModel):
segment: SegmentUpdateArgs
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkListQuery(BaseModel):
limit: int = Field(default=20, ge=1)
keyword: str | None = None
page: int = Field(default=1, ge=1)
class ChildChunkUpdatePayload(BaseModel):
content: str
register_schema_models(
service_api_ns,
SegmentCreatePayload,

View File

@@ -92,7 +92,7 @@ class HumanInputFormApi(Resource):
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
service = HumanInputService(db.engine)
# TODO(QuantumGhost): forbid submision for form tokens
# TODO(QuantumGhost): forbid submission for form tokens
# that are only for console.
form = service.get_form_by_token(form_token)

View File

@@ -3,10 +3,10 @@ from typing import Literal
from flask import request
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
@@ -25,7 +25,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from fields.conversation_fields import ResultResponse
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
from libs import helper
from libs.helper import uuid_value
from models.enums import FeedbackRating
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@@ -41,19 +40,6 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: str = Field(description="Conversation UUID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class MessageMoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"] = Field(
description="Response mode",

View File

@@ -59,6 +59,24 @@ class LangFuseDataTrace(BaseTraceInstance):
)
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
@staticmethod
def _get_completion_start_time(
start_time: datetime | None, time_to_first_token: float | int | None
) -> datetime | None:
"""Convert a relative TTFT value in seconds into Langfuse's absolute completion start time."""
if start_time is None or time_to_first_token is None:
return None
try:
ttft_seconds = float(time_to_first_token)
except (TypeError, ValueError):
return None
if ttft_seconds < 0:
return None
return start_time + timedelta(seconds=ttft_seconds)
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
@@ -189,10 +207,18 @@ class LangFuseDataTrace(BaseTraceInstance):
total_token = metadata.get("total_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
completion_start_time = None
try:
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
usage_data = process_data.get("usage")
if not isinstance(usage_data, dict):
usage_data = outputs.get("usage")
if not isinstance(usage_data, dict):
usage_data = {}
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
completion_start_time = self._get_completion_start_time(
created_at, usage_data.get("time_to_first_token")
)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
@@ -210,6 +236,7 @@ class LangFuseDataTrace(BaseTraceInstance):
trace_id=trace_id,
model=process_data.get("model_name"),
start_time=created_at,
completion_start_time=completion_start_time,
end_time=finished_at,
input=inputs,
output=outputs,
@@ -290,11 +317,16 @@ class LangFuseDataTrace(BaseTraceInstance):
unit=UnitEnum.TOKENS,
totalCost=message_data.total_price,
)
completion_start_time = self._get_completion_start_time(
trace_info.start_time,
trace_info.gen_ai_server_time_to_first_token,
)
langfuse_generation_data = LangfuseGeneration(
name="llm",
trace_id=trace_id,
start_time=trace_info.start_time,
completion_start_time=completion_start_time,
end_time=trace_info.end_time,
model=message_data.model_id,
input=trace_info.inputs,
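The TTFT handling above is simple arithmetic: a relative time-to-first-token in seconds is added to the generation start time to get Langfuse's absolute `completion_start_time`, with `None` returned for missing, unparseable, or negative inputs. A standalone sketch of the same conversion:

```python
# Hedged sketch of the conversion implemented by _get_completion_start_time.
from datetime import datetime, timedelta


def completion_start(start: datetime | None, ttft: float | int | None) -> datetime | None:
    if start is None or ttft is None:
        return None
    try:
        seconds = float(ttft)
    except (TypeError, ValueError):
        return None
    return start + timedelta(seconds=seconds) if seconds >= 0 else None


start = datetime(2026, 4, 13, 12, 0, 0)
assert completion_start(start, 0.42) == start + timedelta(seconds=0.42)
assert completion_start(start, -1) is None  # negative TTFT is discarded
```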

View File

@@ -0,0 +1,87 @@
"""Vector store backend discovery.
Backends live in workspace packages under ``api/packages/dify-vdb-*/src/dify_vdb_*``. Each package
declares third-party dependencies and registers ``importlib`` entry points in group
``dify.vector_backends`` (see each package's ``pyproject.toml``).
Shared types and the :class:`~core.rag.datasource.vdb.vector_factory.AbstractVectorFactory` protocol
remain in this package (``vector_base``, ``vector_factory``, ``vector_type``, ``field``).
Optional **built-in** targets in ``_BUILTIN_VECTOR_FACTORY_TARGETS`` (normally empty) load without a
distribution; entry points take precedence when both exist.
After changing packages, run ``uv sync`` so installed dist-info entry points match ``pyproject.toml``.
"""
from __future__ import annotations
import importlib
import logging
from importlib.metadata import entry_points
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
logger = logging.getLogger(__name__)
_VECTOR_FACTORY_CACHE: dict[str, type[AbstractVectorFactory]] = {}
# module_path:class_name — optional fallback when no distribution registers the backend.
_BUILTIN_VECTOR_FACTORY_TARGETS: dict[str, str] = {}
def clear_vector_factory_cache() -> None:
"""Drop lazily loaded factories (for tests or plugin reload)."""
_VECTOR_FACTORY_CACHE.clear()
def _vector_backend_entry_points():
return entry_points().select(group="dify.vector_backends")
def _load_plugin_factory(vector_type: str) -> type[AbstractVectorFactory] | None:
for ep in _vector_backend_entry_points():
if ep.name != vector_type:
continue
try:
loaded = ep.load()
except Exception:
logger.exception("Failed to load vector backend entry point %s", ep.name)
raise
return loaded # type: ignore[return-value]
return None
def _unsupported(vector_type: str) -> ValueError:
installed = sorted(ep.name for ep in _vector_backend_entry_points())
available_msg = f" Installed backends: {', '.join(installed)}." if installed else " No backends installed."
return ValueError(
f"Vector store {vector_type!r} is not supported.{available_msg} "
"Install a plugin (uv sync --group vdb-all, or vdb-<backend> per api/pyproject.toml), "
"or register a dify.vector_backends entry point."
)
def _load_builtin_factory(vector_type: str) -> type[AbstractVectorFactory]:
target = _BUILTIN_VECTOR_FACTORY_TARGETS.get(vector_type)
if not target:
raise _unsupported(vector_type)
module_path, _, attr = target.partition(":")
module = importlib.import_module(module_path)
return getattr(module, attr) # type: ignore[no-any-return]
def get_vector_factory_class(vector_type: str) -> type[AbstractVectorFactory]:
"""Resolve :class:`AbstractVectorFactory` for a :class:`~VectorType` string value."""
if vector_type in _VECTOR_FACTORY_CACHE:
return _VECTOR_FACTORY_CACHE[vector_type]
plugin_cls = _load_plugin_factory(vector_type)
if plugin_cls is not None:
_VECTOR_FACTORY_CACHE[vector_type] = plugin_cls
return plugin_cls
cls = _load_builtin_factory(vector_type)
_VECTOR_FACTORY_CACHE[vector_type] = cls
return cls
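A short usage sketch for the registry (the backend name `"pgvector"` is assumed to be installed via its workspace package; an unknown name raises the `ValueError` built by `_unsupported`, which lists the installed backends):

```python
# Hedged sketch: resolving a factory class through the entry-point registry.
from core.rag.datasource.vdb.vector_backend_registry import get_vector_factory_class

try:
    factory_cls = get_vector_factory_class("pgvector")
    factory = factory_cls()  # AbstractVectorFactory subclass
except ValueError as exc:
    print(f"backend not available: {exc}")
```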

View File

@@ -9,6 +9,7 @@ from sqlalchemy import select
from configs import dify_config
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_backend_registry import get_vector_factory_class
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
@@ -41,7 +42,23 @@ class AbstractVectorFactory(ABC):
class Vector:
def __init__(self, dataset: Dataset, attributes: list | None = None):
if attributes is None:
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
# `is_summary` and `original_chunk_id` are stored on summary vectors
# by `SummaryIndexService` and read back by `RetrievalService` to
# route summary hits through their original parent chunks. They
# must be listed here so vector backends that use this list as an
# explicit return-properties projection (notably Weaviate) actually
# return those fields; without them, summary hits silently
# collapse into `is_summary = False` branches and the summary
# retrieval path is a no-op. See #34884.
attributes = [
"doc_id",
"dataset_id",
"document_id",
"doc_hash",
"doc_type",
"is_summary",
"original_chunk_id",
]
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
@@ -69,137 +86,7 @@ class Vector:
@staticmethod
def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
match vector_type:
case VectorType.CHROMA:
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory
return ChromaVectorFactory
case VectorType.MILVUS:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
return MilvusVectorFactory
case VectorType.ALIBABACLOUD_MYSQL:
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVectorFactory,
)
return AlibabaCloudMySQLVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
return MyScaleVectorFactory
case VectorType.PGVECTOR:
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
return PGVectorFactory
case VectorType.VASTBASE:
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory
return VastbaseVectorFactory
case VectorType.PGVECTO_RS:
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
return PGVectoRSFactory
case VectorType.QDRANT:
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
return QdrantVectorFactory
case VectorType.RELYT:
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
return RelytVectorFactory
case VectorType.ELASTICSEARCH:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.ELASTICSEARCH_JA:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
ElasticSearchJaVectorFactory,
)
return ElasticSearchJaVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
return TiDBVectorFactory
case VectorType.WEAVIATE:
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
return WeaviateVectorFactory
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
return TencentVectorFactory
case VectorType.ORACLE:
from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory
return OracleVectorFactory
case VectorType.OPENSEARCH:
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
return OpenSearchVectorFactory
case VectorType.ANALYTICDB:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
return AnalyticdbVectorFactory
case VectorType.COUCHBASE:
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory
return CouchbaseVectorFactory
case VectorType.BAIDU:
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
return BaiduVectorFactory
case VectorType.VIKINGDB:
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory
return VikingDBVectorFactory
case VectorType.UPSTASH:
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
return UpstashVectorFactory
case VectorType.TIDB_ON_QDRANT:
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
return TidbOnQdrantVectorFactory
case VectorType.LINDORM:
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
return LindormVectorStoreFactory
case VectorType.OCEANBASE | VectorType.SEEKDB:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
return OceanBaseVectorFactory
case VectorType.OPENGAUSS:
from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory
return OpenGaussFactory
case VectorType.TABLESTORE:
from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory
return TableStoreVectorFactory
case VectorType.HUAWEI_CLOUD:
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
return HuaweiCloudVectorFactory
case VectorType.MATRIXONE:
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
return MatrixoneVectorFactory
case VectorType.CLICKZETTA:
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
return ClickzettaVectorFactory
case VectorType.IRIS:
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
return IrisVectorFactory
case VectorType.HOLOGRES:
from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory
return HologresVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
return get_vector_factory_class(vector_type)
def create(self, texts: list | None = None, **kwargs):
if texts:

View File

@@ -1,10 +1,19 @@
"""Shared helpers for vector DB integration tests (used by workspace packages under ``api/packages``).
:class:`AbstractVectorTest` and helper functions live here so package tests can import
``core.rag.datasource.vdb.vector_integration_test_support`` without relying on the
``tests.*`` package.
The ``setup_mock_redis`` fixture lives in ``api/packages/conftest.py`` and is
auto-discovered by pytest for all package tests.
"""
import uuid
from unittest.mock import MagicMock
import pytest
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions import ext_redis
from models.dataset import Dataset
@@ -25,24 +34,10 @@ def get_example_document(doc_id: str) -> Document:
return doc
@pytest.fixture
def setup_mock_redis():
# get
ext_redis.redis_client.get = MagicMock(return_value=None)
# set
ext_redis.redis_client.set = MagicMock(return_value=None)
# lock
mock_redis_lock = MagicMock()
mock_redis_lock.__enter__ = MagicMock()
mock_redis_lock.__exit__ = MagicMock()
ext_redis.redis_client.lock = mock_redis_lock
class AbstractVectorTest:
vector: BaseVector
def __init__(self):
self.vector = None
self.dataset_id = str(uuid.uuid4())
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
self.example_doc_id = str(uuid.uuid4())

View File

@@ -244,7 +244,7 @@ class DatasetDocumentStore:
return document_segment
def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None):
if multimodel_documents:
if multimodel_documents and self._document_id is not None:
for multimodel_document in multimodel_documents:
binding = SegmentAttachmentBinding(
tenant_id=self._dataset.tenant_id,

View File

@@ -4,7 +4,12 @@ from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEve
from core.rag.entities.index_entities import EconomySetting, EmbeddingSetting, IndexMethod
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator
from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation
from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting, WeightedScoreConfig
from core.rag.entities.retrieval_settings import (
KeywordSetting,
RerankingModelConfig,
VectorSetting,
WeightedScoreConfig,
)
__all__ = [
"Condition",
@@ -19,6 +24,7 @@ __all__ = [
"MetadataFilteringCondition",
"ParentMode",
"PreProcessingRule",
"RerankingModelConfig",
"RetrievalSourceMetadata",
"Rule",
"Segmentation",

View File

@@ -1,4 +1,27 @@
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, Field
class RerankingModelConfig(BaseModel):
"""
Canonical reranking model configuration.
Accepts both naming conventions:
- reranking_provider_name / reranking_model_name (services layer)
- provider / model (workflow layer via validation_alias)
"""
model_config = ConfigDict(populate_by_name=True)
reranking_provider_name: str = Field(validation_alias="provider")
reranking_model_name: str = Field(validation_alias="model")
@property
def provider(self) -> str:
return self.reranking_provider_name
@property
def model(self) -> str:
return self.reranking_model_name
class VectorSetting(BaseModel):
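Because `RerankingModelConfig` sets `populate_by_name=True` alongside `validation_alias`, both naming conventions hydrate the same canonical fields. A hedged sketch (provider/model values are illustrative):

```python
# Hedged sketch: workflow-style and services-style payloads validate to the
# same canonical RerankingModelConfig.
from core.rag.entities import RerankingModelConfig

workflow_style = RerankingModelConfig.model_validate(
    {"provider": "cohere", "model": "rerank-english-v3.0"}
)
services_style = RerankingModelConfig.model_validate(
    {"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v3.0"}
)
assert workflow_style == services_style
assert workflow_style.provider == workflow_style.reranking_provider_name
```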

View File

@@ -4,21 +4,12 @@ from graphon.entities.base_node_data import BaseNodeData
from graphon.enums import NodeType
from pydantic import BaseModel
from core.rag.entities.retrieval_settings import WeightedScoreConfig
from core.rag.entities import RerankingModelConfig, WeightedScoreConfig
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
class RerankingModelConfig(BaseModel):
"""
Reranking Model Config.
"""
reranking_provider_name: str
reranking_model_name: str
class RetrievalSetting(BaseModel):
"""
Retrieval Setting.

View File

@@ -5,20 +5,11 @@ from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.nodes.llm.entities import ModelConfig, VisionConfig
from pydantic import BaseModel, Field
from core.rag.entities import Condition, MetadataFilteringCondition, WeightedScoreConfig
from core.rag.entities import Condition, MetadataFilteringCondition, RerankingModelConfig, WeightedScoreConfig
__all__ = ["Condition"]
class RerankingModelConfig(BaseModel):
"""
Reranking Model Config.
"""
provider: str
model: str
class MultipleRetrievalConfig(BaseModel):
"""
Multiple Retrieval Config.

View File

@@ -17,7 +17,6 @@ def http_status_message(code):
def register_external_error_handlers(api: Api):
@api.errorhandler(HTTPException)
def handle_http_exception(e: HTTPException):
got_request_exception.send(current_app, exception=e)
@@ -74,27 +73,18 @@ def register_external_error_handlers(api: Api):
headers["Set-Cookie"] = build_force_logout_cookie_headers()
return data, status_code, headers
_ = handle_http_exception
@api.errorhandler(ValueError)
def handle_value_error(e: ValueError):
got_request_exception.send(current_app, exception=e)
status_code = 400
data = {"code": "invalid_param", "message": str(e), "status": status_code}
return data, status_code
_ = handle_value_error
@api.errorhandler(AppInvokeQuotaExceededError)
def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
got_request_exception.send(current_app, exception=e)
status_code = 429
data = {"code": "too_many_requests", "message": str(e), "status": status_code}
return data, status_code
_ = handle_quota_exceeded
@api.errorhandler(Exception)
def handle_general_exception(e: Exception):
got_request_exception.send(current_app, exception=e)
@@ -113,7 +103,10 @@ def register_external_error_handlers(api: Api):
return data, status_code
_ = handle_general_exception
api.errorhandler(HTTPException)(handle_http_exception)
api.errorhandler(ValueError)(handle_value_error)
api.errorhandler(AppInvokeQuotaExceededError)(handle_quota_exceeded)
api.errorhandler(Exception)(handle_general_exception)
class ExternalApi(Api):

View File

@@ -1688,7 +1688,7 @@ class PipelineRecommendedPlugin(TypeBase):
)
class SegmentAttachmentBinding(Base):
class SegmentAttachmentBinding(TypeBase):
__tablename__ = "segment_attachment_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"),
@@ -1701,13 +1701,17 @@ class SegmentAttachmentBinding(Base):
),
sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
class DocumentSegmentSummary(Base):

View File

@@ -838,7 +838,7 @@ class AppModelConfig(TypeBase):
return self
class RecommendedApp(Base): # bug
class RecommendedApp(TypeBase):
__tablename__ = "recommended_apps"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
@@ -846,20 +846,37 @@ class RecommendedApp(Base): # bug
sa.Index("recommended_app_is_listed_idx", "is_listed", "language"),
)
id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
app_id = mapped_column(StringUUID, nullable=False)
description = mapped_column(sa.JSON, nullable=False)
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
description: Mapped[Any] = mapped_column(sa.JSON, nullable=False)
copyright: Mapped[str] = mapped_column(String(255), nullable=False)
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False)
custom_disclaimer: Mapped[str] = mapped_column(LongText, default="")
category: Mapped[str] = mapped_column(String(255), nullable=False)
custom_disclaimer: Mapped[str] = mapped_column(LongText, default="")
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
language: Mapped[str] = mapped_column(
String(255),
nullable=False,
server_default=sa.text("'en-US'"),
default="en-US",
)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
@property
@@ -1061,7 +1078,7 @@ class Conversation(Base):
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
message_annotations = db.relationship(
"MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all"
lambda: MessageAnnotation, backref="conversation", lazy="select", passive_deletes="all"
)
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@@ -1820,7 +1837,7 @@ class MessageFile(TypeBase):
)
class MessageAnnotation(Base):
class MessageAnnotation(TypeBase):
__tablename__ = "message_annotations"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
@@ -1829,17 +1846,25 @@ class MessageAnnotation(Base):
sa.Index("message_annotation_message_idx", "message_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
app_id: Mapped[str] = mapped_column(StringUUID)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
message_id: Mapped[str | None] = mapped_column(StringUUID)
question: Mapped[str] = mapped_column(LongText, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), default=None)
message_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), default=0)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
@property
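The column changes above follow SQLAlchemy 2.0's dataclass-mapping idiom: `default_factory` supplies the Python-side dataclass default, `insert_default` covers plain Core inserts, and `init=False` keeps server-managed fields out of the generated `__init__`. A self-contained sketch of the pattern (model and table names are illustrative, not from this codebase):

```python
# Hedged sketch of the insert_default/default_factory/init=False pattern used
# when migrating models to a MappedAsDataclass base such as TypeBase.
import uuid

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column


class ExampleBase(MappedAsDataclass, DeclarativeBase):
    pass


class ExampleRow(ExampleBase):
    __tablename__ = "example_rows"

    id: Mapped[str] = mapped_column(
        sa.String(36),
        primary_key=True,
        insert_default=lambda: str(uuid.uuid4()),
        default_factory=lambda: str(uuid.uuid4()),
        init=False,
    )
    name: Mapped[str] = mapped_column(sa.String(255))


row = ExampleRow(name="demo")  # id is generated; it is not an __init__ argument
```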

View File

@@ -61,7 +61,7 @@ from factories import variable_factory
from libs import helper
from .account import Account
from .base import Base, DefaultFieldsMixin, TypeBase
from .base import Base, DefaultFieldsDCMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
from .types import EnumText, LongText, StringUUID
@@ -671,6 +671,29 @@ class Workflow(Base): # bug
return str(d)
class WorkflowRunDict(TypedDict):
id: str
tenant_id: str
app_id: str
workflow_id: str
type: WorkflowType
triggered_from: WorkflowRunTriggeredFrom
version: str
graph: Mapping[str, Any]
inputs: Mapping[str, Any]
status: WorkflowExecutionStatus
outputs: Mapping[str, Any]
error: str | None
elapsed_time: float
total_tokens: int
total_steps: int
created_by_role: CreatorUserRole
created_by: str
created_at: datetime
finished_at: datetime | None
exceptions_count: int
class WorkflowRun(Base):
"""
Workflow Run
@@ -742,8 +765,8 @@ class WorkflowRun(Base):
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
"WorkflowPause",
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
lambda: WorkflowPause,
primaryjoin=lambda: WorkflowRun.id == orm.foreign(WorkflowPause.workflow_run_id),
uselist=False,
# require explicit preloading.
lazy="raise",
@@ -790,29 +813,29 @@ class WorkflowRun(Base):
def workflow(self):
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
def to_dict(self):
return {
"id": self.id,
"tenant_id": self.tenant_id,
"app_id": self.app_id,
"workflow_id": self.workflow_id,
"type": self.type,
"triggered_from": self.triggered_from,
"version": self.version,
"graph": self.graph_dict,
"inputs": self.inputs_dict,
"status": self.status,
"outputs": self.outputs_dict,
"error": self.error,
"elapsed_time": self.elapsed_time,
"total_tokens": self.total_tokens,
"total_steps": self.total_steps,
"created_by_role": self.created_by_role,
"created_by": self.created_by,
"created_at": self.created_at,
"finished_at": self.finished_at,
"exceptions_count": self.exceptions_count,
}
def to_dict(self) -> WorkflowRunDict:
return WorkflowRunDict(
id=self.id,
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
type=self.type,
triggered_from=self.triggered_from,
version=self.version,
graph=self.graph_dict,
inputs=self.inputs_dict,
status=self.status,
outputs=self.outputs_dict,
error=self.error,
elapsed_time=self.elapsed_time,
total_tokens=self.total_tokens,
total_steps=self.total_steps,
created_by_role=self.created_by_role,
created_by=self.created_by,
created_at=self.created_at,
finished_at=self.finished_at,
exceptions_count=self.exceptions_count,
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
@@ -1196,6 +1219,18 @@ class WorkflowAppLogCreatedFrom(StrEnum):
raise ValueError(f"invalid workflow app log created from value {value}")
class WorkflowAppLogDict(TypedDict):
id: str
tenant_id: str
app_id: str
workflow_id: str
workflow_run_id: str
created_from: WorkflowAppLogCreatedFrom
created_by_role: CreatorUserRole
created_by: str
created_at: datetime
class WorkflowAppLog(TypeBase):
"""
Workflow App execution log, excluding workflow debugging records.
@@ -1273,8 +1308,8 @@ class WorkflowAppLog(TypeBase):
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
def to_dict(self):
return {
def to_dict(self) -> WorkflowAppLogDict:
result: WorkflowAppLogDict = {
"id": self.id,
"tenant_id": self.tenant_id,
"app_id": self.app_id,
@@ -1285,6 +1320,7 @@ class WorkflowAppLog(TypeBase):
"created_by": self.created_by,
"created_at": self.created_at,
}
return result
class WorkflowArchiveLog(TypeBase):
@@ -1941,7 +1977,7 @@ def is_system_variable_editable(name: str) -> bool:
return name in _EDITABLE_SYSTEM_VARIABLE
class WorkflowPause(DefaultFieldsMixin, Base):
class WorkflowPause(DefaultFieldsDCMixin, TypeBase):
"""
WorkflowPause records the paused state and related metadata for a specific workflow run.
@@ -1980,6 +2016,11 @@ class WorkflowPause(DefaultFieldsMixin, Base):
nullable=False,
)
# state_object_key stores the object key referencing the serialized runtime state
# of the `GraphEngine`. This object captures the complete execution context of the
# workflow at the moment it was paused, enabling accurate resumption.
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
# `resumed_at` records the timestamp when the suspended workflow was resumed.
# It is set to `NULL` if the workflow has not been resumed.
#
@@ -1988,25 +2029,23 @@ class WorkflowPause(DefaultFieldsMixin, Base):
resumed_at: Mapped[datetime | None] = mapped_column(
sa.DateTime,
nullable=True,
default=None,
)
# state_object_key stores the object key referencing the serialized runtime state
# of the `GraphEngine`. This object captures the complete execution context of the
# workflow at the moment it was paused, enabling accurate resumption.
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
# Relationship to WorkflowRun
# Relationship to WorkflowRun (uses lambda to resolve across Base/TypeBase registries)
workflow_run: Mapped["WorkflowRun"] = orm.relationship(
lambda: WorkflowRun,
foreign_keys=[workflow_run_id],
# require explicit preloading.
lazy="raise",
uselist=False,
primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id",
primaryjoin=lambda: WorkflowPause.workflow_run_id == WorkflowRun.id,
back_populates="pause",
init=False,
)
class WorkflowPauseReason(DefaultFieldsMixin, Base):
class WorkflowPauseReason(DefaultFieldsDCMixin, TypeBase):
__tablename__ = "workflow_pause_reasons"
# `pause_id` represents the identifier of the pause,
@@ -2049,16 +2088,20 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
lazy="raise",
uselist=False,
primaryjoin="WorkflowPauseReason.pause_id == WorkflowPause.id",
init=False,
)
@classmethod
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
def from_entity(cls, *, pause_id: str, pause_reason: PauseReason) -> "WorkflowPauseReason":
if isinstance(pause_reason, HumanInputRequired):
return cls(
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
pause_id=pause_id,
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
form_id=pause_reason.form_id,
node_id=pause_reason.node_id,
)
elif isinstance(pause_reason, SchedulingPause):
return cls(type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message, node_id="")
return cls(pause_id=pause_id, type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message)
else:
raise AssertionError(f"Unknown pause reason type: {pause_reason}")
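A brief sketch of the updated call site: `from_entity` now takes `pause_id` as an explicit keyword, so a reason row is bound to its pause at construction time (`pause.id` and `reason` stand in for objects from the surrounding code):

```python
# Hedged sketch: the keyword-only signature prevents creating an orphaned
# WorkflowPauseReason without its owning pause.
reason_row = WorkflowPauseReason.from_entity(pause_id=pause.id, pause_reason=reason)
```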


api/providers/README.md Normal file
View File

@@ -0,0 +1,12 @@
# Providers
This directory holds **optional workspace packages** that plug into Dify's API core. Providers implement the core interfaces and register themselves with the API core. The provider mechanism allows building the software with a selected set of providers, which improves the security and flexibility of distributions.
## Developing Providers
- [VDB Providers](vdb/README.md)
## Tests
Provider tests often live next to the package, e.g. `providers/<type>/<backend>/tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`).

View File

@@ -0,0 +1,58 @@
# VDB providers
This directory contains all VDB providers.
## Architecture
1. **Core** (`api/core/rag/datasource/vdb/`) defines the contracts and loads plugins.
2. **Each provider** (`api/providers/vdb/<backend>/`) implements those contracts and registers an entry point.
3. At runtime, **`importlib.metadata.entry_points`** resolves the backend name (e.g. `pgvector`) to a factory class. The registry caches loaded classes (see `vector_backend_registry.py`).
### Interfaces
| Piece | Role |
|--------|----------|
| `AbstractVectorFactory` | You subclass this. Implement `init_vector(dataset, attributes, embeddings) -> BaseVector`. Optionally use `gen_index_struct_dict()` for new datasets. |
| `BaseVector` | Your store class subclasses this: `create`, `add_texts`, `search_by_vector`, `delete`, etc. |
| `VectorType` | `StrEnum` of supported backend **string ids**. Add a member when you introduce a new backend that should be selectable like existing ones. |
| Discovery | Loads `dify.vector_backends` entry points and caches `get_vector_factory_class(vector_type)`. |
The high-level caller is `Vector` in `vector_factory.py`: it reads the configured or dataset-specific vector type, calls `get_vector_factory_class`, instantiates the factory, and uses the returned `BaseVector` implementation.
### Entry point name must match the vector type string
Entry points are registered under the group **`dify.vector_backends`**. The **entry point name** (left-hand side) must be exactly the string used as `vector_type` everywhere else—typically the **`VectorType` enum value** (e.g. `PGVECTOR = "pgvector"` → entry point name `pgvector`; `TIDB_ON_QDRANT = "tidb_on_qdrant"` → `tidb_on_qdrant`).
In `pyproject.toml`:
```toml
[project.entry-points."dify.vector_backends"]
pgvector = "dify_vdb_pgvector.pgvector:PGVectorFactory"
```
The value is **`module:attribute`**: an importable module path and the class implementing `AbstractVectorFactory`.
### How registration works
1. On first use, `get_vector_factory_class(vector_type)` looks up `vector_type` in a process cache.
2. If missing, it scans **`entry_points().select(group="dify.vector_backends")`** for an entry whose **`name` equals `vector_type`**.
3. It loads that entry (`ep.load()`), which must return the **factory class** (not an instance).
4. There is an optional internal map `_BUILTIN_VECTOR_FACTORY_TARGETS` for non-distribution builtins; **normal VDB plugins use entry points only**.
After you change a provider's `pyproject.toml` (entry points or dependencies), run **`uv sync`** in `api/` so the installed environment's dist-info matches the project metadata.
### Package layout (VDB)
Each backend usually follows:
- `api/providers/vdb/<backend>/pyproject.toml` — project name `dify-vdb-<backend>`, dependencies, entry points.
- `api/providers/vdb/<backend>/src/dify_vdb_<python_package>/` — implementation (e.g. `PGVector`, `PGVectorFactory`).
See `vdb/pgvector/` as a reference implementation.
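A hedged skeleton of a new backend module, assuming the layout described above (`src/dify_vdb_mine/mine_vector.py` in a hypothetical package `dify-vdb-mine`); method bodies are placeholders, and the `MineVector` constructor arguments are illustrative rather than a required signature:

```python
# Registered in pyproject.toml as:
#   [project.entry-points."dify.vector_backends"]
#   mine = "dify_vdb_mine.mine_vector:MineVectorFactory"
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.embedding.embedding_base import Embeddings
from models.dataset import Dataset


class MineVector(BaseVector):
    # implement create, add_texts, search_by_vector, delete, ... (see BaseVector)
    ...


class MineVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
        collection_name = Dataset.gen_collection_name_by_id(dataset.id)
        return MineVector(collection_name)  # constructor args illustrative
```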
### Wiring a new backend into the API workspace
The API uses a **uv workspace** (`api/pyproject.toml`):
1. **`[tool.uv.workspace]`** — `members = ["providers/vdb/*"]` already includes every subdirectory under `vdb/`; new folders there are workspace members.
2. **`[tool.uv.sources]`** — add a line for your package: `dify-vdb-mine = { workspace = true }`.
3. **`[project.optional-dependencies]`** — add a group such as `vdb-mine = ["dify-vdb-mine"]`, and list `dify-vdb-mine` under `vdb-all` if it should install with the default bundle.

View File

@@ -0,0 +1,22 @@
from unittest.mock import MagicMock
import pytest
from extensions import ext_redis
@pytest.fixture(autouse=True)
def _init_mock_redis():
"""Ensure redis_client has a backing client so __getattr__ never raises."""
if ext_redis.redis_client._client is None:
ext_redis.redis_client.initialize(MagicMock())
@pytest.fixture
def setup_mock_redis(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(ext_redis.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(ext_redis.redis_client, "set", MagicMock(return_value=None))
mock_redis_lock = MagicMock()
mock_redis_lock.__enter__ = MagicMock()
mock_redis_lock.__exit__ = MagicMock()
monkeypatch.setattr(ext_redis.redis_client, "lock", mock_redis_lock)
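A sketch of how a package test consumes these fixtures: `setup_mock_redis` patches `get`/`set`/`lock` only for the test's duration, and `monkeypatch` undoes the patches automatically afterwards, unlike the old fixture's module-level `MagicMock` assignments.

```python
# Hedged sketch of a test relying on the monkeypatch-based fixture above.
from extensions import ext_redis


def test_cache_miss_path(setup_mock_redis):
    assert ext_redis.redis_client.get("any-key") is None
```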

View File

@@ -0,0 +1,13 @@
[project]
name = "dify-vdb-alibabacloud-mysql"
version = "0.0.1"
dependencies = [
"mysql-connector-python>=9.3.0",
]
description = "Dify vector store backend (dify-vdb-alibabacloud-mysql)."
[project.entry-points."dify.vector_backends"]
alibabacloud_mysql = "dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector:AlibabaCloudMySQLVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,10 +1,9 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module
import pytest
import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory
from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory
def test_validate_distance_function_accepts_supported_values():

View File

@@ -3,11 +3,11 @@ import unittest
from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVector,
AlibabaCloudMySQLVectorConfig,
)
from core.rag.models.document import Document
try:
@@ -49,9 +49,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
# Sample embeddings
self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_init(self, mock_pool_class):
"""Test AlibabaCloudMySQLVector initialization."""
# Mock the connection pool
@@ -76,10 +74,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert alibabacloud_mysql_vector.distance_function == "cosine"
assert alibabacloud_mysql_vector.pool is not None
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
def test_create_collection(self, mock_redis, mock_pool_class):
"""Test collection creation."""
# Mock Redis operations
@@ -110,9 +106,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes
mock_redis.set.assert_called_once()
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_vector_support_check_success(self, mock_pool_class):
"""Test successful vector support check."""
# Mock the connection pool
@@ -129,9 +123,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
assert vector_store is not None
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_vector_support_check_failure(self, mock_pool_class):
"""Test vector support check failure."""
# Mock the connection pool
@@ -149,9 +141,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "RDS MySQL Vector functions are not available" in str(context.value)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_vector_support_check_function_error(self, mock_pool_class):
"""Test vector support check with function not found error."""
# Mock the connection pool
@@ -170,10 +160,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "RDS MySQL Vector functions are not available" in str(context.value)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
def test_create_documents(self, mock_redis, mock_pool_class):
"""Test creating documents with embeddings."""
# Setup mocks
@@ -186,9 +174,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "doc1" in result
assert "doc2" in result
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_add_texts(self, mock_pool_class):
"""Test adding texts to the vector store."""
# Mock the connection pool
@@ -207,9 +193,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert len(result) == 2
mock_cursor.executemany.assert_called_once()
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_text_exists(self, mock_pool_class):
"""Test checking if text exists."""
# Mock the connection pool
@@ -236,9 +220,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "SELECT id FROM" in last_call[0][0]
assert last_call[0][1] == ("doc1",)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_text_not_exists(self, mock_pool_class):
"""Test checking if text does not exist."""
# Mock the connection pool
@@ -260,9 +242,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert not exists
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_get_by_ids(self, mock_pool_class):
"""Test getting documents by IDs."""
# Mock the connection pool
@@ -288,9 +268,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert docs[0].page_content == "Test document 1"
assert docs[1].page_content == "Test document 2"
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_get_by_ids_empty_list(self, mock_pool_class):
"""Test getting documents with empty ID list."""
# Mock the connection pool
@@ -308,9 +286,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert len(docs) == 0
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_delete_by_ids(self, mock_pool_class):
"""Test deleting documents by IDs."""
# Mock the connection pool
@@ -334,9 +310,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "DELETE FROM" in delete_call[0][0]
assert delete_call[0][1] == ["doc1", "doc2"]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_delete_by_ids_empty_list(self, mock_pool_class):
"""Test deleting with empty ID list."""
# Mock the connection pool
@@ -357,9 +331,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
assert len(delete_calls) == 0
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_delete_by_ids_table_not_exists(self, mock_pool_class):
"""Test deleting when table doesn't exist."""
# Mock the connection pool
@@ -384,9 +356,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
# Should not raise an exception
vector_store.delete_by_ids(["doc1"])
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_delete_by_metadata_field(self, mock_pool_class):
"""Test deleting documents by metadata field."""
# Mock the connection pool
@@ -410,9 +380,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0]
assert delete_call[0][1] == ("$.document_id", "dataset1")
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_vector_cosine(self, mock_pool_class):
"""Test vector search with cosine distance."""
# Mock the connection pool
@@ -437,9 +405,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert abs(docs[0].metadata["score"] - 0.9) < 0.1 # 1 - 0.1 = 0.9
assert docs[0].metadata["distance"] == 0.1
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_vector_euclidean(self, mock_pool_class):
"""Test vector search with euclidean distance."""
config = AlibabaCloudMySQLVectorConfig(
@@ -472,9 +438,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert len(docs) == 1
assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01 # 1/(1+2) = 1/3
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_vector_with_filter(self, mock_pool_class):
"""Test vector search with document ID filter."""
# Mock the connection pool
@@ -499,9 +463,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
search_call = search_calls[0]
assert "WHERE JSON_UNQUOTE" in search_call[0][0]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_vector_with_score_threshold(self, mock_pool_class):
"""Test vector search with score threshold."""
# Mock the connection pool
@@ -536,9 +498,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert len(docs) == 1
assert docs[0].page_content == "High similarity document"
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_vector_invalid_top_k(self, mock_pool_class):
"""Test vector search with invalid top_k."""
# Mock the connection pool
@@ -560,9 +520,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
with pytest.raises(ValueError):
vector_store.search_by_vector(query_vector, top_k="invalid")
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_full_text(self, mock_pool_class):
"""Test full-text search."""
# Mock the connection pool
@@ -591,9 +549,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert docs[0].page_content == "This document contains machine learning content"
assert docs[0].metadata["score"] == 1.5
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_full_text_with_filter(self, mock_pool_class):
"""Test full-text search with document ID filter."""
# Mock the connection pool
@@ -617,9 +573,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
search_call = search_calls[0]
assert "AND JSON_UNQUOTE" in search_call[0][0]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_full_text_invalid_top_k(self, mock_pool_class):
"""Test full-text search with invalid top_k."""
# Mock the connection pool
@@ -640,9 +594,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
with pytest.raises(ValueError):
vector_store.search_by_full_text("test", top_k="invalid")
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_delete_collection(self, mock_pool_class):
"""Test deleting the entire collection."""
# Mock the connection pool
@@ -665,9 +617,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
drop_call = drop_calls[0]
assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_unsupported_distance_function(self, mock_pool_class):
"""Test that Pydantic validation rejects unsupported distance functions."""
# Test that creating config with unsupported distance function raises ValidationError
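Editor's note: every test in this file retargets the same `@patch` string from the old `core.rag.datasource.vdb.alibabacloud_mysql` path to the new `dify_vdb_alibabacloud_mysql` package. This follows from how `unittest.mock.patch` works: the target string names the location where the object is looked up, so when a module moves, every patch target that names it must move too. A minimal sketch of the idea, assuming the relocated module is importable:

```python
# Sketch: patch() resolves its target string only when the patcher starts,
# so the string must name the module's *new* import path after the move.
from unittest.mock import MagicMock, patch

NEW_TARGET = (
    "dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector"
    ".mysql.connector.pooling.MySQLConnectionPool"
)


def make_pool_patch():
    # Patching the old core.rag... path would no longer intercept the pool
    # construction, because the code under test now lives in dify_vdb_*.
    return patch(NEW_TARGET, MagicMock())
```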

View File

@@ -0,0 +1,15 @@
[project]
name = "dify-vdb-analyticdb"
version = "0.0.1"
dependencies = [
"alibabacloud_gpdb20160503~=5.2.0",
"alibabacloud_tea_openapi~=0.4.3",
"clickhouse-connect~=0.15.0",
]
description = "Dify vector store backend (dify-vdb-analyticdb)."
[project.entry-points."dify.vector_backends"]
analyticdb = "dify_vdb_analyticdb.analyticdb_vector:AnalyticdbVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -2,16 +2,16 @@ import json
from typing import Any
from configs import dify_config
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from dify_vdb_analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from models.dataset import Dataset

View File

@@ -1,9 +1,8 @@
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest
from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector
from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest
class AnalyticdbVectorTest(AbstractVectorTest):

View File

@@ -1,12 +1,12 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import dify_vdb_analyticdb.analyticdb_vector as analyticdb_module
import pytest
from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory
from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from core.rag.models.document import Document

View File

@@ -4,13 +4,13 @@ import types
from types import SimpleNamespace
from unittest.mock import MagicMock
import dify_vdb_analyticdb.analyticdb_vector_openapi as openapi_module
import pytest
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
from dify_vdb_analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from core.rag.models.document import Document

View File

@@ -2,14 +2,14 @@ from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock
import dify_vdb_analyticdb.analyticdb_vector_sql as sql_module
import psycopg2.errors
import pytest
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import (
from dify_vdb_analyticdb.analyticdb_vector_sql import (
AnalyticdbVectorBySql,
AnalyticdbVectorBySqlConfig,
)
from core.rag.models.document import Document

View File

@@ -0,0 +1,13 @@
[project]
name = "dify-vdb-baidu"
version = "0.0.1"
dependencies = [
"pymochow==2.4.0",
]
description = "Dify vector store backend (dify-vdb-baidu)."
[project.entry-points."dify.vector_backends"]
baidu = "dify_vdb_baidu.baidu_vector:BaiduVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,10 +1,6 @@
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
from dify_vdb_baidu.baidu_vector import BaiduConfig, BaiduVector
pytest_plugins = (
"tests.integration_tests.vdb.test_vector_store",
"tests.integration_tests.vdb.__mock.baiduvectordb",
)
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
class BaiduVectorTest(AbstractVectorTest):

View File

@@ -124,7 +124,7 @@ def _build_fake_pymochow_modules():
def baidu_module(monkeypatch):
for name, module in _build_fake_pymochow_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.baidu.baidu_vector as module
import dify_vdb_baidu.baidu_vector as module
return importlib.reload(module)
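Editor's note: this fixture, like its counterparts for the other backends below, installs fake SDK modules into `sys.modules` and then reloads the module under test. The reload matters because the module may already have been imported with the real SDK bound at module level; re-executing it rebinds those names to the fakes. A condensed, generic sketch with illustrative names:

```python
# Sketch of the recurring fixture pattern: swap in an empty stand-in for the
# SDK, then re-execute the module under test so `import <sdk>` binds the fake.
import importlib
import sys
import types


def swap_sdk_and_reload(module_name: str, sdk_name: str):
    sys.modules[sdk_name] = types.ModuleType(sdk_name)  # install the fake
    module = importlib.import_module(module_name)       # may return a cached copy
    return importlib.reload(module)                     # reload re-runs the module body
```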

View File

@@ -0,0 +1,13 @@
[project]
name = "dify-vdb-chroma"
version = "0.0.1"
dependencies = [
"chromadb==0.5.20",
]
description = "Dify vector store backend (dify-vdb-chroma)."
[project.entry-points."dify.vector_backends"]
chroma = "dify_vdb_chroma.chroma_vector:ChromaVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,13 +1,11 @@
import chromadb
from dify_vdb_chroma.chroma_vector import ChromaConfig, ChromaVector
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaConfig, ChromaVector
from tests.integration_tests.vdb.test_vector_store import (
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest,
get_example_text,
)
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class ChromaVectorTest(AbstractVectorTest):
def __init__(self):

View File

@@ -47,7 +47,7 @@ def _build_fake_chroma_modules():
def chroma_module(monkeypatch):
fake_chroma = _build_fake_chroma_modules()
monkeypatch.setitem(sys.modules, "chromadb", fake_chroma)
import core.rag.datasource.vdb.chroma.chroma_vector as module
import dify_vdb_chroma.chroma_vector as module
return importlib.reload(module)

View File

@@ -198,4 +198,4 @@ Clickzetta supports advanced full-text search with multiple analyzers:
- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search)
- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index)
- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)
- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-vdb-clickzetta"
version = "0.0.1"
dependencies = [
"clickzetta-connector-python>=0.8.102",
]
description = "Dify vector store backend (dify-vdb-clickzetta)."
[project.entry-points."dify.vector_backends"]
clickzetta = "dify_vdb_clickzetta.clickzetta_vector:ClickzettaVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -2,10 +2,10 @@ import contextlib
import os
import pytest
from dify_vdb_clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
class TestClickzettaVector(AbstractVectorTest):
@@ -14,9 +14,8 @@ class TestClickzettaVector(AbstractVectorTest):
"""
@pytest.fixture
def vector_store(self):
def vector_store(self, setup_mock_redis):
"""Create a Clickzetta vector store instance for testing."""
# Skip test if Clickzetta credentials are not configured
if not os.getenv("CLICKZETTA_USERNAME"):
pytest.skip("CLICKZETTA_USERNAME is not configured")
if not os.getenv("CLICKZETTA_PASSWORD"):
@@ -32,21 +31,19 @@ class TestClickzettaVector(AbstractVectorTest):
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
batch_size=10, # Small batch size for testing
batch_size=10,
enable_inverted_index=True,
analyzer_type="chinese",
analyzer_mode="smart",
vector_distance_function="cosine_distance",
)
with setup_mock_redis():
vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
yield vector
yield vector
# Cleanup: delete the test collection
with contextlib.suppress(Exception):
vector.delete()
with contextlib.suppress(Exception):
vector.delete()
def test_clickzetta_vector_basic_operations(self, vector_store):
"""Test basic CRUD operations on Clickzetta vector store."""

View File

@@ -3,16 +3,19 @@
Test Clickzetta integration in Docker environment
"""
import logging
import os
import time
import httpx
from clickzetta import connect
logger = logging.getLogger(__name__)
def test_clickzetta_connection():
"""Test direct connection to Clickzetta"""
print("=== Testing direct Clickzetta connection ===")
logger.info("=== Testing direct Clickzetta connection ===")
try:
conn = connect(
username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
@@ -25,100 +28,93 @@ def test_clickzetta_connection():
)
with conn.cursor() as cursor:
# Test basic connectivity
cursor.execute("SELECT 1 as test")
result = cursor.fetchone()
print(f"✓ Connection test: {result}")
logger.info("✓ Connection test: %s", result)
# Check if our test table exists
cursor.execute("SHOW TABLES IN dify")
tables = cursor.fetchall()
print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}")
logger.info("✓ Existing tables: %s", [t[1] for t in tables if t[0] == "dify"])
# Check if test collection exists
test_collection = "collection_test_dataset"
if test_collection in [t[1] for t in tables if t[0] == "dify"]:
cursor.execute(f"DESCRIBE dify.{test_collection}")
columns = cursor.fetchall()
print(f"✓ Table structure for {test_collection}:")
logger.info("✓ Table structure for %s:", test_collection)
for col in columns:
print(f" - {col[0]}: {col[1]}")
logger.info(" - %s: %s", col[0], col[1])
# Check for indexes
cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
indexes = cursor.fetchall()
print(f"✓ Indexes on {test_collection}:")
logger.info("✓ Indexes on %s:", test_collection)
for idx in indexes:
print(f" - {idx}")
logger.info(" - %s", idx)
return True
except Exception as e:
print(f"✗ Connection test failed: {e}")
except Exception:
logger.exception("✗ Connection test failed")
return False
def test_dify_api():
"""Test Dify API with Clickzetta backend"""
print("\n=== Testing Dify API ===")
logger.info("\n=== Testing Dify API ===")
base_url = "http://localhost:5001"
# Wait for API to be ready
max_retries = 30
for i in range(max_retries):
try:
response = httpx.get(f"{base_url}/console/api/health")
if response.status_code == 200:
print("✓ Dify API is ready")
logger.info("✓ Dify API is ready")
break
except:
if i == max_retries - 1:
print("✗ Dify API is not responding")
logger.exception("✗ Dify API is not responding")
return False
time.sleep(2)
# Check vector store configuration
try:
# This is a simplified check - in production, you'd use proper auth
print("✓ Dify is configured to use Clickzetta as vector store")
logger.info("✓ Dify is configured to use Clickzetta as vector store")
return True
except Exception as e:
print(f"✗ API test failed: {e}")
except Exception:
logger.exception("✗ API test failed")
return False
def verify_table_structure():
"""Verify the table structure meets Dify requirements"""
print("\n=== Verifying Table Structure ===")
logger.info("\n=== Verifying Table Structure ===")
expected_columns = {
"id": "VARCHAR",
"page_content": "VARCHAR",
"metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta
"metadata": "VARCHAR",
"vector": "ARRAY<FLOAT>",
}
expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"]
print("✓ Expected table structure:")
logger.info("✓ Expected table structure:")
for col, dtype in expected_columns.items():
print(f" - {col}: {dtype}")
logger.info(" - %s: %s", col, dtype)
print("\n✓ Required metadata fields:")
logger.info("\n✓ Required metadata fields:")
for field in expected_metadata_fields:
print(f" - {field}")
logger.info(" - %s", field)
print("\n✓ Index requirements:")
print(" - Vector index (HNSW) on 'vector' column")
print(" - Full-text index on 'page_content' (optional)")
print(" - Functional index on metadata->>'$.doc_id' (recommended)")
print(" - Functional index on metadata->>'$.document_id' (recommended)")
logger.info("\n✓ Index requirements:")
logger.info(" - Vector index (HNSW) on 'vector' column")
logger.info(" - Full-text index on 'page_content' (optional)")
logger.info(" - Functional index on metadata->>'$.doc_id' (recommended)")
logger.info(" - Functional index on metadata->>'$.document_id' (recommended)")
return True
def main():
"""Run all tests"""
print("Starting Clickzetta integration tests for Dify Docker\n")
logger.info("Starting Clickzetta integration tests for Dify Docker\n")
tests = [
("Direct Clickzetta Connection", test_clickzetta_connection),
@@ -131,33 +127,34 @@ def main():
try:
success = test_func()
results.append((test_name, success))
except Exception as e:
print(f"\n{test_name} crashed: {e}")
except Exception:
logger.exception("\n%s crashed", test_name)
results.append((test_name, False))
# Summary
print("\n" + "=" * 50)
print("Test Summary:")
print("=" * 50)
logger.info("\n%s", "=" * 50)
logger.info("Test Summary:")
logger.info("=" * 50)
passed = sum(1 for _, success in results if success)
total = len(results)
for test_name, success in results:
status = "✅ PASSED" if success else "❌ FAILED"
print(f"{test_name}: {status}")
logger.info("%s: %s", test_name, status)
print(f"\nTotal: {passed}/{total} tests passed")
logger.info("\nTotal: %s/%s tests passed", passed, total)
if passed == total:
print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
print("\nNext steps:")
print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d")
print("2. Access Dify at http://localhost:3000")
print("3. Create a dataset and test vector storage with Clickzetta")
logger.info("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
logger.info("\nNext steps:")
logger.info(
"1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d"
)
logger.info("2. Access Dify at http://localhost:3000")
logger.info("3. Create a dataset and test vector storage with Clickzetta")
return 0
else:
print("\n⚠️ Some tests failed. Please check the errors above.")
logger.error("\n⚠️ Some tests failed. Please check the errors above.")
return 1
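Editor's note: the migration above replaces `print()` with module-level logging throughout. Two properties of the stdlib `logging` API motivate the new call shapes: passing `%`-style arguments defers string formatting until the record is actually emitted, and `logger.exception(...)` logs at ERROR level with the active traceback attached, which is why the `as e` bindings were dropped. A small self-contained illustration:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    raise ConnectionError("simulated failure")
except Exception:
    # No f-string: formatting is deferred, and the traceback is appended
    # automatically by logger.exception.
    logger.exception("Connection test failed for %s", "clickzetta")
```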

View File

@@ -47,7 +47,7 @@ def _build_fake_clickzetta_module():
@pytest.fixture
def clickzetta_module(monkeypatch):
monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module())
import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module
import dify_vdb_clickzetta.clickzetta_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-vdb-couchbase"
version = "0.0.1"
dependencies = [
"couchbase~=4.6.0",
]
description = "Dify vector store backend (dify-vdb-couchbase)."
[project.entry-points."dify.vector_backends"]
couchbase = "dify_vdb_couchbase.couchbase_vector:CouchbaseVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,12 +1,14 @@
import logging
import subprocess
import time
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
from tests.integration_tests.vdb.test_vector_store import (
from dify_vdb_couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest,
)
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
logger = logging.getLogger(__name__)
def wait_for_healthy_container(service_name="couchbase-server", timeout=300):
@@ -16,10 +18,10 @@ def wait_for_healthy_container(service_name="couchbase-server", timeout=300):
["docker", "inspect", "--format", "{{.State.Health.Status}}", service_name], capture_output=True, text=True
)
if result.stdout.strip() == "healthy":
print(f"{service_name} is healthy!")
logger.info("%s is healthy!", service_name)
return True
else:
print(f"Waiting for {service_name} to be healthy...")
logger.info("Waiting for %s to be healthy...", service_name)
time.sleep(10)
raise TimeoutError(f"{service_name} did not become healthy in time")

View File

@@ -154,7 +154,7 @@ def couchbase_module(monkeypatch):
for name, module in _build_fake_couchbase_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.couchbase.couchbase_vector as module
import dify_vdb_couchbase.couchbase_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,15 @@
[project]
name = "dify-vdb-elasticsearch"
version = "0.0.1"
dependencies = [
"elasticsearch==8.14.0",
]
description = "Dify vector store backend (dify-vdb-elasticsearch)."
[project.entry-points."dify.vector_backends"]
elasticsearch = "dify_vdb_elasticsearch.elasticsearch_vector:ElasticSearchVectorFactory"
elasticsearch-ja = "dify_vdb_elasticsearch.elasticsearch_ja_vector:ElasticSearchJaVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -4,14 +4,14 @@ from typing import Any
from flask import current_app
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from dify_vdb_elasticsearch.elasticsearch_vector import (
ElasticSearchConfig,
ElasticSearchVector,
ElasticSearchVectorFactory,
)
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@@ -1,10 +1,9 @@
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
from tests.integration_tests.vdb.test_vector_store import (
from dify_vdb_elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest,
)
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class ElasticSearchVectorTest(AbstractVectorTest):
def __init__(self):

View File

@@ -32,8 +32,8 @@ def elasticsearch_ja_module(monkeypatch):
for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module
import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module
import dify_vdb_elasticsearch.elasticsearch_ja_vector as ja_module
import dify_vdb_elasticsearch.elasticsearch_vector as base_module
importlib.reload(base_module)
return importlib.reload(ja_module)

View File

@@ -42,7 +42,7 @@ def elasticsearch_module(monkeypatch):
for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module
import dify_vdb_elasticsearch.elasticsearch_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-vdb-hologres"
version = "0.0.1"
dependencies = [
"holo-search-sdk>=0.4.2",
]
description = "Dify vector store backend (dify-vdb-hologres)."
[project.entry-points."dify.vector_backends"]
hologres = "dify_vdb_hologres.hologres_vector:HologresVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any
from typing import Any, cast
import holo_search_sdk as holo # type: ignore
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
@@ -351,9 +351,9 @@ class HologresVectorFactory(AbstractVectorFactory):
access_key_id=dify_config.HOLOGRES_ACCESS_KEY_ID or "",
access_key_secret=dify_config.HOLOGRES_ACCESS_KEY_SECRET or "",
schema_name=dify_config.HOLOGRES_SCHEMA,
tokenizer=dify_config.HOLOGRES_TOKENIZER,
distance_method=dify_config.HOLOGRES_DISTANCE_METHOD,
base_quantization_type=dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE,
tokenizer=cast(TokenizerType, dify_config.HOLOGRES_TOKENIZER),
distance_method=cast(DistanceType, dify_config.HOLOGRES_DISTANCE_METHOD),
base_quantization_type=cast(BaseQuantizationType, dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE),
max_degree=dify_config.HOLOGRES_MAX_DEGREE,
ef_construction=dify_config.HOLOGRES_EF_CONSTRUCTION,
),
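Editor's note: the `cast(...)` wrappers added here satisfy the SDK's narrower parameter types without any runtime effect, since `typing.cast` only informs the type checker and returns its argument unchanged. A small self-contained illustration, using a stand-in `Literal` alias rather than the real SDK types:

```python
from typing import Literal, cast

TokenizerType = Literal["simple", "chinese"]  # stand-in for the SDK alias


def configure(tokenizer: TokenizerType) -> str:
    # No conversion happens here; the annotation is purely static.
    return tokenizer


raw: str = "chinese"                        # config values arrive as plain str
print(configure(cast(TokenizerType, raw)))  # cast() returns `raw` unchanged
```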

View File

@@ -7,13 +7,10 @@ import pytest
from _pytest.monkeypatch import MonkeyPatch
from psycopg import sql as psql
# Shared in-memory storage: {table_name: {doc_id: {"id", "text", "meta", "embedding"}}}
_mock_tables: dict[str, dict[str, dict[str, Any]]] = {}
class MockSearchQuery:
"""Mock query builder for search_vector and search_text results."""
def __init__(self, table_name: str, search_type: str):
self._table_name = table_name
self._search_type = search_type
@@ -32,17 +29,13 @@ class MockSearchQuery:
return self
def _apply_filter(self, row: dict[str, Any]) -> bool:
"""Apply the filter SQL to check if a row matches."""
if self._filter_sql is None:
return True
# Extract literals (the document IDs) from the filter SQL
# Filter format: meta->>'document_id' IN ('doc1', 'doc2')
literals = [v for t, v in _extract_identifiers_and_literals(self._filter_sql) if t == "literal"]
if not literals:
return True
# Get the document_id from the row's meta field
meta = row.get("meta", "{}")
if isinstance(meta, str):
meta = json.loads(meta)
@@ -54,22 +47,17 @@ class MockSearchQuery:
data = _mock_tables.get(self._table_name, {})
results = []
for row in list(data.values())[: self._limit_val]:
# Apply filter if present
if not self._apply_filter(row):
continue
if self._search_type == "vector":
# row format expected by _process_vector_results: (distance, id, text, meta)
results.append((0.1, row["id"], row["text"], row["meta"]))
else:
# row format expected by _process_full_text_results: (id, text, meta, embedding, score)
results.append((row["id"], row["text"], row["meta"], row.get("embedding", []), 0.9))
return results
class MockTable:
"""Mock table object returned by client.open_table()."""
def __init__(self, table_name: str):
self._table_name = table_name
@@ -97,7 +85,6 @@ class MockTable:
def _extract_sql_template(query) -> str:
"""Extract the SQL template string from a psycopg Composed object."""
if isinstance(query, psql.Composed):
for part in query:
if isinstance(part, psql.SQL):
@@ -108,7 +95,6 @@ def _extract_sql_template(query) -> str:
def _extract_identifiers_and_literals(query) -> list[Any]:
"""Extract Identifier and Literal values from a psycopg Composed object."""
values: list[Any] = []
if isinstance(query, psql.Composed):
for part in query:
@@ -117,7 +103,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
elif isinstance(part, psql.Literal):
values.append(("literal", part._obj))
elif isinstance(part, psql.Composed):
# Handles SQL(...).join(...) for IN clauses
for sub in part:
if isinstance(sub, psql.Literal):
values.append(("literal", sub._obj))
@@ -125,8 +110,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
class MockHologresClient:
"""Mock holo_search_sdk client that stores data in memory."""
def connect(self):
pass
@@ -141,21 +124,18 @@ class MockHologresClient:
params = _extract_identifiers_and_literals(query)
if "CREATE TABLE" in template.upper():
# Extract table name from first identifier
table_name = next((v for t, v in params if t == "ident"), "unknown")
if table_name not in _mock_tables:
_mock_tables[table_name] = {}
return None
if "SELECT 1" in template:
# text_exists: SELECT 1 FROM {table} WHERE id = {id} LIMIT 1
table_name = next((v for t, v in params if t == "ident"), "")
doc_id = next((v for t, v in params if t == "literal"), "")
data = _mock_tables.get(table_name, {})
return [(1,)] if doc_id in data else []
if "SELECT id" in template:
# get_ids_by_metadata_field: SELECT id FROM {table} WHERE meta->>{key} = {value}
table_name = next((v for t, v in params if t == "ident"), "")
literals = [v for t, v in params if t == "literal"]
key = literals[0] if len(literals) > 0 else ""
@@ -166,12 +146,10 @@ class MockHologresClient:
if "DELETE" in template.upper():
table_name = next((v for t, v in params if t == "ident"), "")
if "id IN" in template:
# delete_by_ids
ids_to_delete = [v for t, v in params if t == "literal"]
for did in ids_to_delete:
_mock_tables.get(table_name, {}).pop(did, None)
elif "meta->>" in template:
# delete_by_metadata_field
literals = [v for t, v in params if t == "literal"]
key = literals[0] if len(literals) > 0 else ""
value = literals[1] if len(literals) > 1 else ""
@@ -190,7 +168,6 @@ class MockHologresClient:
def mock_connect(**kwargs):
"""Replacement for holo_search_sdk.connect() that returns a mock client."""
return MockHologresClient()

View File

@@ -2,16 +2,11 @@ import os
import uuid
from typing import cast
from dify_vdb_hologres.hologres_vector import HologresVector, HologresVectorConfig
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from core.rag.datasource.vdb.hologres.hologres_vector import HologresVector, HologresVectorConfig
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
pytest_plugins = (
"tests.integration_tests.vdb.test_vector_store",
"tests.integration_tests.vdb.__mock.hologres",
)
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"

View File

@@ -42,7 +42,7 @@ def hologres_module(monkeypatch):
for name, module in _build_fake_hologres_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.hologres.hologres_vector as module
import dify_vdb_hologres.hologres_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-vdb-huawei-cloud"
version = "0.0.1"
dependencies = [
"elasticsearch==8.14.0",
]
description = "Dify vector store backend (dify-vdb-huawei-cloud)."
[project.entry-points."dify.vector_backends"]
huawei_cloud = "dify_vdb_huawei_cloud.huawei_cloud_vector:HuaweiCloudVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

Some files were not shown because too many files have changed in this diff.