Compare commits

..

251 Commits

Author SHA1 Message Date
-LAN-
4b0d2bf57f chore: update build-push.yml to remove unnecessary tags 2024-09-04 20:20:39 +08:00
-LAN-
94432a0a69 chore: update package versions to 0.8.0-beta1 (#7979) 2024-09-04 19:56:33 +08:00
takatost
7e30487f8b feat: update dsl version 2024-09-04 19:41:43 +08:00
Yi
46634638e7 fix: refine the "isInIteration" for workflow 2024-09-04 17:45:07 +08:00
StyleZhang
44038b9628 fix: iteration copy 2024-09-04 17:26:43 +08:00
StyleZhang
c625f4282f Merge branch 'main' into feat/workflow-parallel-support 2024-09-04 15:22:57 +08:00
StyleZhang
4f5dc82459 fix 2024-09-04 15:03:35 +08:00
Yi
5cb018e15d update the method to check if a node is in iteration 2024-09-04 14:59:30 +08:00
KVOJJJin
14af87527f Feat:remove estimation of embedding cost (#7950)
Co-authored-by: jyong <718720800@qq.com>
2024-09-04 14:41:47 +08:00
zhuhao
83e84865be feat: add health check for pg and redis in docker-compose.middleware.yaml (#7961) (#7962) 2024-09-04 14:25:46 +08:00
StyleZhang
4962b2c460 check node edge 2024-09-04 13:27:17 +08:00
Joel
c2a3c5a748 fix: get commit sha failed in translate action (#7959) 2024-09-04 13:13:21 +08:00
呆萌闷油瓶
83494cb4f5 fix:empty voice occurs when xinference CosyVoice tts model (#7958) 2024-09-04 13:04:31 +08:00
Vico Chu
0bc19c3fbf Feat: update app published time after clicking publish button (#7801) 2024-09-04 13:03:06 +08:00
Sumkor
571415d1a4 fix: split text keep separator (#7930) 2024-09-04 12:59:10 +08:00
Yuki Oshima
7b2cf8215f chore: fix inverted index japanese translation (#7957) 2024-09-04 12:44:59 +08:00
Joe
fee4d3f6ca feat: ops trace add llm model (#7306) 2024-09-04 10:39:00 +08:00
Yi
cd42dbdae8 update the log for iteration nodes 2024-09-04 10:27:40 +08:00
takatost
161cc0cda9 Revert "fix: an issue of keyword search feature in application log list" (#7949) 2024-09-04 10:00:55 +08:00
Nam Vu
71bff9fcf3 chore: #7943 i18n (#7948) 2024-09-04 09:42:25 +08:00
陳鈞
80d14c9b22 fix(api): Code-Based Extension cause error on position map sorting (#7934)
Signed-off-by: 陳鈞 <jim60105@gmail.com>
2024-09-04 08:41:12 +08:00
crazywoola
c5bdf08558 Chore/add roadmap (#7943) 2024-09-04 08:33:02 +08:00
crazywoola
596f160a1e Chore/add default step 1x url (#7933) 2024-09-04 08:32:22 +08:00
Jyong
d8b6c053a2 fix rerank model value is empty string (#7937) 2024-09-03 21:25:21 +08:00
Nam Vu
4b262cae58 chore: #7603 i18n (#7931) 2024-09-03 19:19:52 +08:00
takatost
78fa1f6868 fix(workflow): detached session issues 2024-09-03 18:23:37 +08:00
Jyong
1a5116cba0 Fix/segment create with api (#7928) 2024-09-03 18:14:47 +08:00
Jyong
01581dd35f improve the notion table extract (#7925) 2024-09-03 17:52:07 +08:00
Yi
6bee121ebe update log in web app 2024-09-03 17:27:32 +08:00
Joel
7fdd964379 fix: frontend handle sometimes server not generate the wrong follow up data struct (#7916) 2024-09-03 14:09:46 +08:00
takatost
36d95e49b0 fix(iteration): iterator_length not correct 2024-09-03 12:01:56 +08:00
Yi
3431b19f9a update styling and iteration log 2024-09-03 11:46:04 +08:00
Yi
b28c7b1cda Merge branch 'feat/workflow-parallel-support' of github.com:langgenius/dify into feat/workflow-parallel-support 2024-09-03 10:35:15 +08:00
Yi
83343eefe6 update parallel log 2024-09-03 10:34:50 +08:00
Joel
0cfcc97e9d feat: support auto generate i18n translate (#6964)
Co-authored-by: crazywoola <427733928@qq.com>
2024-09-03 10:17:05 +08:00
-LAN-
8986be0aab chore: Update versions to 0.7.3 (#7895) 2024-09-03 09:49:32 +08:00
-LAN-
f76bbbf5e6 chore(Dockerfile): Bump expat to 2.6.2-2 (#7904) 2024-09-03 09:48:30 +08:00
kurokobo
fe217da05c fix: correct typo in the setting screen (#7897) 2024-09-02 22:49:56 +08:00
takatost
d92966545b fix: migration 2024-09-02 22:41:08 +08:00
takatost
f71c51cb9a Merge branch 'refs/heads/main' into feat/workflow-parallel-support 2024-09-02 22:37:23 +08:00
takatost
955884b87e chore(workflow): max thread submit count 2024-09-02 20:20:32 +08:00
kurokobo
80aa7c4019 feat: allow users to use the app icon as the answer icon (#7888)
Co-authored-by: crazywoola <427733928@qq.com>
2024-09-02 20:00:41 +08:00
Jyong
6f33351eb3 ignore linked images when image id is none (#7890) 2024-09-02 19:37:05 +08:00
Alex
35f13c7327 Add Russian language (#7860)
Co-authored-by: d8rt8v <alex@ydertev.ru>
Co-authored-by: crazywoola <427733928@qq.com>
2024-09-02 19:09:41 +08:00
takatost
5ca9df65de feat(workflow): add thread pool 2024-09-02 19:02:45 +08:00
takatost
166365a502 feat(workflow): add thread pool 2024-09-02 19:02:21 +08:00
StyleZhang
70aced0100 fix 2024-09-02 18:38:21 +08:00
takatost
35d9c59a29 Merge remote-tracking branch 'origin/feat/workflow-parallel-support' into feat/workflow-parallel-support 2024-09-02 17:56:16 +08:00
takatost
bbc922dffa merge main 2024-09-02 17:55:28 +08:00
StyleZhang
7035f64ce3 fix: next step 2024-09-02 17:52:54 +08:00
takatost
81d09d471c Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/workflow/app_generator.py
2024-09-02 17:52:51 +08:00
takatost
5bda3a384a fix(workflow): bugs 2024-09-02 17:49:51 +08:00
Fei He
a8b9e01b3e fix: fixed typo on loading reranking_mode (#7887) 2024-09-02 16:18:47 +08:00
StyleZhang
43240fcd41 fix 2024-09-02 14:50:05 +08:00
Joshua
7193e189f3 Add perplexity search as a new tool (#7861) 2024-09-02 14:48:13 +08:00
orangeclk
3f2a806abe fix: glm models prices and max_tokens correction (#7882) 2024-09-02 14:29:09 +08:00
takatost
52b4623131 fix(workflow): fix merge branch node id err 2024-09-02 13:56:07 +08:00
takatost
0dabf799c0 fix(workflow): fix merge branch node id err 2024-09-02 11:52:14 +08:00
legao
5e4907e940 fix: layout shift on app card hover (#7872) 2024-09-02 11:05:54 +08:00
omr
bf63c5d1e3 fix typo: langauge -> language (#7875) 2024-09-02 08:41:45 +08:00
Yi
29b1ce781d fix: node end status 2024-09-01 22:00:54 +08:00
Seayon
78989e9049 Add ALIYUN_OSS_PATH configuration for Aliyun OSS (#7864)
Co-authored-by: seayon <zhaoxuyang@shouqianba.com>
2024-09-01 21:30:17 +08:00
Hirotaka Miyagi
1510bdbcf6 refactor: Remove typecasting by any (#7862) 2024-09-01 14:58:12 +08:00
Hirotaka Miyagi
024d688b77 fix(RetrievalConfig): Fix score threshold assignment for zero value (#7865) 2024-09-01 14:57:50 +08:00
zhujinle
ef82a29e23 fix: crash when ECharts accesses undefined objects (#7853) 2024-09-01 14:52:27 +08:00
sino
1f56a20b62 feat: support auth by api key for ark provider (#7845) 2024-08-31 10:56:32 +08:00
Yi
71a7d890cc fix styling 2024-08-30 23:31:05 +08:00
Yi
ee1587c939 fix: make the End node always nested in the root 2024-08-30 20:14:56 +08:00
Yi
d7c0ca852e feat: inner parallels will be added to its corresponding branch 2024-08-30 20:08:57 +08:00
takatost
162e9677c7 fix(workflow): missing parallel event in workflow app 2024-08-30 20:04:17 +08:00
Bowen Liang
0c2a62f847 fix: correct http timeout configs‘ default values and ignorance by HttpRequestNode (#7762) 2024-08-30 19:09:10 +08:00
takatost
77e62f7fee fix(workflow): run node in multi parallel bugs 2024-08-30 18:55:33 +08:00
Ethan
ea748b50f2 fix: an issue of keyword search feature in application log list (#7816) 2024-08-30 18:48:05 +08:00
Yi Xiao
62bfc4dba6 fix: tooltip size sets improperly (#7836) 2024-08-30 18:13:54 +08:00
Yi
e3295181d2 fix a typo 2024-08-30 18:01:13 +08:00
Yi
2b5b856126 solve the branch issue 2024-08-30 17:58:29 +08:00
Yi
e3ae529a55 update the onNodeFinished method for nodes being passed through more than once 2024-08-30 17:00:02 +08:00
Zhi
ceb2b150ff enhance: include workspace name in create-tenant command (#7834) 2024-08-30 15:53:50 +08:00
Yi
708256ef1d Merge branch 'feat/workflow-parallel-support' of github.com:langgenius/dify into feat/workflow-parallel-support 2024-08-30 15:23:21 +08:00
非法操作
dc015c380a feat: add zhipu glm_4_plus and glm_4v_plus model (#7824) 2024-08-30 15:08:31 +08:00
StyleZhang
7c9081a8fc fix 2024-08-30 13:44:01 +08:00
Benjamin
c9e0f0bf20 fix: correct typo in environment variable description (#7817) 2024-08-30 00:03:40 +08:00
YidaHu
bd6d4d0553 fix: filter out installed apps without an app (#7799) 2024-08-29 19:03:08 +08:00
hisir
f0273f00e1 Fixed when testing the openai compatible interface model, an error is reported when no object is returned (#7808) 2024-08-29 18:58:19 +08:00
Yi
1bde57e591 delete console logs 2024-08-29 17:54:26 +08:00
Yi
32a11cbb6a update the parallel workflow log for iteration and chatflow preview 2024-08-29 17:26:17 +08:00
Yi
3e257ae907 update the workflow parallel log 2024-08-29 16:38:51 +08:00
Yeuoly
962cdbbebd chore: add app generator overload (#7792) 2024-08-29 16:04:01 +08:00
NFish
2c51e3a327 fix: webapp sso setting may not the latest value when refresh (#7795) 2024-08-29 15:57:43 +08:00
Jyong
8e311cc45c fixed permission is None (#7788) 2024-08-29 12:46:42 +08:00
crazywoola
c441bea4d1 fix: datasets permission is missing (#7787) 2024-08-29 12:46:33 +08:00
StyleZhang
f43596f226 fix: parallel branch limit 2024-08-29 11:31:34 +08:00
NFish
ad30668eb6 Sync Input component from feat/attachments branch (#7782) 2024-08-29 11:23:16 +08:00
Huang YunKun
62f4801523 Update ssrf_proxy related doc link in docker-compose file (#7778) 2024-08-29 11:22:39 +08:00
kanoshiou
ec1408346e docs: navigate to open issues in contributing documents (#7781) 2024-08-29 11:18:49 +08:00
takatost
0e0a703496 chore: ignore openai error record in sentry (#7770) 2024-08-28 23:26:11 +08:00
takatost
ae22015fe7 fix(workflow): loop check 2024-08-28 21:47:47 +08:00
takatost
790dd3b22f fix(workflow): duplicate nodes in parallel 2024-08-28 19:01:45 +08:00
Garfield Dai
54b693d5b1 feat: update saas billing hint. (#7760) 2024-08-28 18:55:47 +08:00
takatost
5d34e080eb fix: migration 2024-08-28 18:02:49 +08:00
takatost
6b6750b9ad Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/services/app_generate_service.py
2024-08-28 18:01:57 +08:00
Bowen Liang
1262277714 chore: improve http executor configs (#7730) 2024-08-28 17:46:37 +08:00
takatost
74c8004944 fix(graph_engine): fix execute loops in parallel 2024-08-28 17:42:42 +08:00
StyleZhang
4418fa1d2b fix: bug 2024-08-28 17:40:50 +08:00
YidaHu
3a67fc6c5a feat: add support for array types in available variable list (#7715) 2024-08-28 17:30:13 +08:00
zhuhao
26abbe8e5b feat(Tools): add a tool to query the stock price from Alpha Vantage (#7019) (#7752) 2024-08-28 17:27:20 +08:00
Leheng Lu
5d0914daea fix: not able to pass array of string/number/object into variable aggregator groups (#7757) 2024-08-28 17:25:20 +08:00
Joel
7541a492b7 fix: crawl options max length can not set 0 (#7758)
Co-authored-by: Yi <yxiaoisme@gmail.com>
2024-08-28 17:16:07 +08:00
takatost
c2bb11405f fix(workflow): parallel not yield 2024-08-28 16:13:57 +08:00
StyleZhang
8ba5673606 feat: iteration support parallel 2024-08-28 16:00:17 +08:00
takatost
b0a81c654b fix(workflow): parallel execution after if-else that only one branch runs 2024-08-28 15:53:39 +08:00
crazywoola
3a071b8db9 fix: datasets permission is missing (#7751) 2024-08-28 15:36:11 +08:00
snickerjp
9342b4b951 Update package "libldap-2.5-0" for docker build. (#7726) 2024-08-28 14:44:05 +08:00
Vimpas
4682e0ac7c fix(storage): 🐛 HeadBucket Operation Permission (#7733)
Co-authored-by: 莫岳恒 <moyueheng@datagrand.com>
2024-08-28 13:57:45 +08:00
sino
7cfebffbb8 chore: update default endpoint for ark provider (#7741) 2024-08-28 13:56:50 +08:00
KVOJJJin
693fe912f2 Fix annotation reply settings (#7696) 2024-08-28 09:42:54 +08:00
kurokobo
bc3a8e0ca2 feat: store created_by and updated_by for apps, modelconfigs, and sites (#7613) 2024-08-28 08:47:30 +08:00
Jiakun Xu
e38334cfd2 fix: doc_language return null when document segment settings (#7719) 2024-08-28 08:45:51 +08:00
走在修行的大街上
92cab33b73 feat(Tools): add feishu document and message plugins (#6435)
Co-authored-by: 黎斌 <libin.23@bytedance.com>
2024-08-27 20:21:42 +08:00
Bowen Liang
3f467613fc feat: support configs for code execution request (#7704) 2024-08-27 19:38:33 +08:00
Bryan
205d33a813 Fix: read properties of undefined issue (#7708)
Co-authored-by: libing <libing@healink.cn>
2024-08-27 19:23:56 +08:00
crazywoola
da326baa5e fix: tongyi Error: 'NoneType' object is not subscriptable (#7705) 2024-08-27 16:56:06 +08:00
crazywoola
d9198b5646 feat: remove unused code (#7702) 2024-08-27 16:47:34 +08:00
takatost
cd52633b0e fix(graph_engine): parent_parallel_id missing 2024-08-27 16:45:14 +08:00
takatost
4256e9d47f chore(iteration): keep start_node_id using in parallel start nodes 2024-08-27 16:38:33 +08:00
Jyong
60001a62c4 fixed chunk_overlap is None (#7703) 2024-08-27 16:38:06 +08:00
sino
ee7d5e7206 feat: support Moonshot and GLM models tool call for volc ark provider (#7666) 2024-08-27 14:43:37 +08:00
StyleZhang
4e3dc36e37 fix: workflow run edge status 2024-08-27 14:39:56 +08:00
呆萌闷油瓶
2726fb3d5d feat:dailymessages (#7603) 2024-08-27 12:53:27 +08:00
kurokobo
d7aa4076c9 feat: display account name on the logs page for the apps (#7668)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-08-27 12:40:44 +08:00
Kenn
122ce41020 feat: rewrite Elasticsearch index and search code to achieve Elasticsearch vector and full-text search (#7641)
Co-authored-by: haokai <haokai@shuwen.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Bowen Liang <bowenliang@apache.org>
Co-authored-by: wellCh4n <wellCh4n@foxmail.com>
2024-08-27 11:43:44 +08:00
Charlie.Wei
e7afee1176 Langfuse view button (#7684) 2024-08-27 11:25:56 +08:00
zxhlyh
88730906ec fix: empty knowledge add file (#7690) 2024-08-27 11:25:27 +08:00
Bowen Liang
a15080a1d7 bug: (#7586 followup) fix config of CODE_MAX_STRING_LENGTH (#7683) 2024-08-27 10:38:24 +08:00
Jyong
35431bce0d fix dataset_id and index_node_id idx missed in document_segments tabl… (#7681) 2024-08-27 10:25:24 +08:00
Hélio Lúcio
7b7576ad55 Add Azure AI Studio as provider (#7549)
Co-authored-by: Hélio Lúcio <canais.hlucio@voegol.com.br>
2024-08-27 09:52:59 +08:00
Qin Liu
162faee4f2 fix: set score_threshold to zero if it is None for MyScale vectordb (#7640)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-08-27 09:47:16 +08:00
takatost
b9f34f679f fix: iteration start node id 2024-08-26 22:00:17 +08:00
Zhi
b7ff98d7ff fix: Remove useless debug information. (#7647) 2024-08-26 20:40:26 +08:00
-LAN-
0474f0c906 chore: Update version to 0.7.2 (#7646) 2024-08-26 20:11:55 +08:00
Shota Totsuka
430e100142 refactor: Add @staticmethod decorator in api/core (#7652) 2024-08-26 19:45:03 +08:00
Jyong
1473083a41 catch openai rate limit error (#7658) 2024-08-26 19:36:44 +08:00
snickerjp
7cda73f192 Proposal to revise Japanese expressions (#7664) 2024-08-26 19:05:49 +08:00
代君
7c2bb31a55 [fix] openai's tool role dose not support name parameter. (#7659) 2024-08-26 18:52:34 +08:00
StyleZhang
9c8144e463 feat: parallel hover 2024-08-26 17:49:11 +08:00
非法操作
ba82023445 fix: support float type for tool parameter's default value (#7644) 2024-08-26 17:10:54 +08:00
takatost
76bb8d1c1a Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/services/app_generate_service.py
#	api/services/workflow_service.py
2024-08-26 16:17:19 +08:00
StyleZhang
1016db160e feat: parallel hover 2024-08-26 16:09:22 +08:00
-LAN-
13be84e4d4 chore(api/controllers): Apply Ruff Formatter. (#7645) 2024-08-26 15:29:10 +08:00
Jyong
7ae728a9a3 fix nltk averaged_perceptron_tagger download and fix score limit is none (#7582)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-08-26 15:14:05 +08:00
pinsily
a7743a4f47 add:save_model_credentials error log (#7630) 2024-08-26 14:46:29 +08:00
Zhi
103ff28530 feat: speed up the Docker build for dify-api for Chinese developers. (#7626) 2024-08-26 14:45:28 +08:00
Zhi
8dfdb37de3 fix: use LOG_LEVEL for celery startup (#7628) 2024-08-26 14:44:58 +08:00
Bowen Liang
17fd773a30 chore(api/services): apply ruff reformatting (#7599)
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-08-26 13:43:57 +08:00
Bowen Liang
979422cdc6 chore(api/tasks): apply ruff reformatting (#7594) 2024-08-26 13:38:37 +08:00
Yi Xiao
3be756eaed feat: tooltip (#7634) 2024-08-26 13:00:02 +08:00
crazywoola
1ba3d3acd6 feat: replace show/hide workflow_steps with switch (#7627) 2024-08-26 11:00:57 +08:00
takatost
6c61776ee1 fix test 2024-08-25 22:02:21 +08:00
NFish
23cedc3f1c Web app now supports SSO config (#7137) 2024-08-25 18:47:16 +08:00
Joe
741c548f3c feat: web sso (#7135) 2024-08-25 18:47:02 +08:00
非法操作
556f4ad5df feat: add siliconflow text2img tool (#7612) 2024-08-25 14:39:58 +08:00
Seayon
561a61e7fe Improve MIME type detection for image URLs (#6531)
Co-authored-by: seayon <zhaoxuyang@shouqianba.com>
2024-08-25 13:36:16 +08:00
Shota Totsuka
47919983bf fix: typo in comment (#7606) 2024-08-25 09:56:08 +08:00
sino
efc136cce5 feat: Introduce Ark SDK v3 and ensure compatibility with models of SDK v2 (#7579)
Co-authored-by: crazywoola <427733928@qq.com>
2024-08-24 19:29:45 +08:00
takatost
4771e85630 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/tests/integration_tests/workflow/nodes/test_code.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
2024-08-24 17:26:44 +08:00
takatost
85d319719c fix end node bug 2024-08-24 17:17:18 +08:00
Bowen Liang
b035c02f78 chore(api/tests): apply ruff reformat #7590 (#7591)
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-08-23 23:52:25 +08:00
Bowen Liang
2da63654e5 chore(api/configs): apply ruff reformat (#7590) 2024-08-23 23:46:01 +08:00
Bowen Liang
3ace01cfb3 chore: cleanup and rearrange unclassified configs into feature config groups (#7586) 2024-08-23 22:40:07 +08:00
Junyan Qin
e3d7c7c6f9 fix(onebot): use yarl to format url (#7589) 2024-08-23 22:22:42 +08:00
Junyan Qin
8807d880dc Feat: add OneBot protocol tool (#7583) 2024-08-23 19:16:30 +08:00
Jie.F
70d6ab0bf5 Update stable_diffusion.py (#7536) 2024-08-23 18:58:13 +08:00
Amos
e42848f4b7 Do not pass query parameter when the value is empty (#7585) 2024-08-23 18:50:38 +08:00
Yi Xiao
25386af41a fix: knowledge setting "knowledge name" input width (#7584) 2024-08-23 17:20:19 +08:00
张皮皮
f29685f8a1 fix score_threshold is none, return all top K documents (#7581) 2024-08-23 16:59:34 +08:00
噢哎哟喂
ad13011043 add JSON Mode support for moonshot models (#7568) 2024-08-23 16:24:45 +08:00
Charlie.Wei
df69ad9f0e Langfuse view button (#7578) 2024-08-23 16:23:26 +08:00
takatost
42899fb3be fix bug 2024-08-23 00:38:42 +08:00
takatost
5b22d8f8b2 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/nodes/question_classifier/question_classifier_node.py
2024-08-23 00:32:28 +08:00
takatost
fe2b300288 fix lint 2024-08-22 23:54:07 +08:00
takatost
ec4fc784f0 fix iteration start node 2024-08-22 23:53:44 +08:00
takatost
d6da7b0336 fix dialogue_count 2024-08-22 13:06:17 +08:00
takatost
92072e2ed7 fix: ruff issues 2024-08-21 17:26:51 +08:00
takatost
e34497ded1 fix: merge issues 2024-08-21 17:25:26 +08:00
takatost
35be41b337 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/task_pipeline/workflow_cycle_state_manager.py
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/code/code_node.py
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/nodes/start/start_node.py
#	api/core/workflow/nodes/variable_assigner/__init__.py
#	api/tests/integration_tests/workflow/nodes/test_llm.py
#	api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
2024-08-21 16:59:23 +08:00
takatost
412be6d014 fix bug 2024-08-21 16:43:00 +08:00
takatost
1d88b62e25 fix(workflow): fix node link to previous node issue 2024-08-20 23:28:11 +08:00
takatost
617ea4b3b8 fix(workflow): fix parallel bug 2024-08-20 22:16:41 +08:00
takatost
755a9658c7 fix(workflow): add parallel id into published events 2024-08-18 20:18:13 +08:00
takatost
5d7865737f fix(workflow): issues in workflow parallels 2024-08-16 22:47:58 +08:00
takatost
352c45c8a2 feat(workflow): integrate parallel into workflow apps 2024-08-16 21:33:09 +08:00
StyleZhang
1973f5003b feat: frontend support parallel 2024-08-16 16:55:08 +08:00
takatost
5b5e6e31bf fix: answer node unit tests 2024-08-16 01:44:00 +08:00
takatost
90221c0a90 fix: unit tests 2024-08-16 01:43:35 +08:00
takatost
91e51ce2b8 fix(workflow): issues by merging main branch 2024-08-16 01:36:19 +08:00
takatost
db9b0ee985 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/base_app_runner.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/task_pipeline/workflow_cycle_state_manager.py
#	api/core/workflow/entities/node_entities.py
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/workflow_engine_manager.py
#	api/tests/integration_tests/workflow/nodes/test_llm.py
#	api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
2024-08-16 01:19:29 +08:00
takatost
c5192650fb fix: unit tests in workflow 2024-08-15 23:47:59 +08:00
takatost
702df31db7 fix(workflow): fix generate issues in workflow 2024-08-15 20:45:23 +08:00
takatost
1da5862a96 feat(workflow): fix iteration single debug 2024-08-15 03:12:49 +08:00
takatost
6f6b32e1ee feat(workflow): integrate workflow entry with workflow app 2024-08-14 19:22:15 +08:00
takatost
674af04c39 fix migration version depends 2024-08-13 17:15:21 +08:00
takatost
2980e31ddf fix issues when merging from main 2024-08-13 17:11:19 +08:00
takatost
14d020fffe Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/task_pipeline/workflow_cycle_manage.py
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/base_node.py
#	api/core/workflow/workflow_engine_manager.py
2024-08-13 17:05:39 +08:00
takatost
8401a11109 feat(workflow): integrate workflow entry with advanced chat app 2024-08-13 16:21:10 +08:00
takatost
8d27ec364f fix bug 2024-07-31 02:27:23 +08:00
takatost
c9bb366e1a Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/iteration/iteration_node.py
#	api/core/workflow/workflow_engine_manager.py
2024-07-31 02:25:31 +08:00
takatost
917aacbf7f add chatflow app event convert 2024-07-31 02:21:35 +08:00
takatost
0818b7b078 remove iteration special logic 2024-07-26 21:27:01 +08:00
takatost
88dcd7b737 fix bug 2024-07-26 20:29:12 +08:00
takatost
63addf8c94 add parallel branch events 2024-07-26 20:27:17 +08:00
takatost
483f71f03c fix logging 2024-07-26 20:13:11 +08:00
takatost
beea1e1663 fix lint 2024-07-26 19:47:12 +08:00
takatost
38f8c45755 add events in interation node 2024-07-26 19:47:02 +08:00
takatost
a31feacf28 fix iteration 2024-07-26 02:43:40 +08:00
takatost
ae351bd40e add iteration support 2024-07-25 23:07:27 +08:00
takatost
df133168dd fix lint 2024-07-25 21:06:23 +08:00
takatost
7c67ba8991 remove threadpool 2024-07-25 21:05:53 +08:00
takatost
4097f7c069 add parallel branch output 2024-07-25 19:39:06 +08:00
takatost
f4eb7cd037 add end stream output test 2024-07-25 04:03:53 +08:00
takatost
833584ba76 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/workflow_entry.py
2024-07-24 23:43:14 +08:00
takatost
ec7760795f save 2024-07-24 00:24:24 +08:00
takatost
e9bfedab9b Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/entities/variable_pool.py
2024-07-23 17:28:57 +08:00
takatost
7303b53af1 fix bug 2024-07-23 16:18:52 +08:00
takatost
0fe516568a Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/nodes/code/code_node.py
#	api/core/workflow/nodes/end/end_node.py
#	api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
#	api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
#	api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
2024-07-23 16:18:34 +08:00
takatost
2c695ded79 fix bugs 2024-07-23 00:10:23 +08:00
takatost
a603e01f5e fix bug 2024-07-22 19:57:32 +08:00
takatost
beaac5033a fix bug 2024-07-20 00:57:41 +08:00
takatost
dad1a967ee finished answer stream output 2024-07-20 00:49:46 +08:00
takatost
7ad77e9e77 fix test 2024-07-18 08:19:58 +08:00
takatost
f67a88f44d fix test 2024-07-17 21:17:04 +08:00
takatost
90e518b05b fix bugs 2024-07-17 16:54:49 +08:00
takatost
cc96acdae3 fix bugs 2024-07-17 11:26:33 +08:00
takatost
16e2d00157 optimize 2024-07-17 01:07:23 +08:00
takatost
4ef3d4e65c optimize 2024-07-17 01:02:40 +08:00
takatost
775e52db4d merge 2024-07-16 17:46:20 +08:00
takatost
00ec36d47c add graph engine test 2024-07-16 16:37:37 +08:00
takatost
00fb23d0c9 graph engine implement 2024-07-15 23:40:02 +08:00
takatost
821e09b259 add run logics 2024-07-12 19:33:47 +08:00
takatost
d77b689a99 completed parallel tests 2024-07-10 21:21:06 +08:00
takatost
0e885a3cae refactor runtime 2024-07-08 16:29:13 +08:00
takatost
1adaf42f9d refactor graph 2024-07-07 23:08:45 +08:00
takatost
fed068ac2e Merge branch 'refs/heads/main' into feat/workflow-parallel-support 2024-07-07 16:57:21 +08:00
takatost
03f56a05eb refactor graph 2024-07-06 03:18:02 +08:00
takatost
1b6cd975f3 completed graph init test 2024-07-04 15:40:20 +08:00
takatost
0f19b2a986 optimize graph 2024-07-02 21:53:41 +08:00
takatost
8375517ccd save 2024-06-29 15:44:52 +08:00
takatost
1d8ecac093 save 2024-06-27 05:30:38 +08:00
takatost
aaa98c76d5 optimize 2024-06-26 23:56:30 +08:00
takatost
216910a4a1 add runtime state of graph 2024-06-25 17:43:13 +08:00
takatost
fe27c97fd9 add runtime graph 2024-06-25 14:41:14 +08:00
takatost
8217c46116 add new graph structure 2024-06-24 23:34:42 +08:00
882 changed files with 32289 additions and 20432 deletions

View File

@@ -125,7 +125,6 @@ jobs:
with:
images: ${{ env[matrix.image_name_env] }}
tags: |
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}

View File

@@ -0,0 +1,54 @@
name: Check i18n Files and Create PR
on:
pull_request:
types: [closed]
branches: [main]
jobs:
check-and-update:
if: github.event.pull_request.merged == true
runs-on: ubuntu-latest
defaults:
run:
working-directory: web
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 2 # last 2 commits
- name: Check for file changes in i18n/en-US
id: check_files
run: |
recent_commit_sha=$(git rev-parse HEAD)
second_recent_commit_sha=$(git rev-parse HEAD~1)
changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
- name: Set up Node.js
if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v2
with:
node-version: 'lts/*'
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
run: yarn install --frozen-lockfile
- name: Run npm script
if: env.FILES_CHANGED == 'true'
run: npm run auto-gen-i18n
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
commit-message: Update i18n files based on en-US changes
title: 'chore: translate i18n files'
body: This PR was automatically created to update i18n files based on changes in en-US locale.
branch: chore/automated-i18n-updates

View File

@@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
## Before you jump in
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
### Feature requests:

View File

@@ -8,7 +8,7 @@
## 在开始之前
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:open)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
### 功能请求:

View File

@@ -10,7 +10,7 @@ Dify にコントリビュートしたいとお考えなのですね。それは
## 飛び込む前に
[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
### 機能リクエスト

View File

@@ -8,7 +8,7 @@ Về vấn đề cấp phép, xin vui lòng dành chút thời gian đọc qua [
## Trước khi bắt đầu
[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:
[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:open) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:
### Yêu cầu tính năng:

View File

@@ -60,7 +60,8 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key
ALIYUN_OSS_ENDPOINT=your-endpoint
ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string

1
api/.idea/vcs.xml generated
View File

@@ -12,5 +12,6 @@
</component>
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component>
</project>

View File

@@ -5,6 +5,10 @@ WORKDIR /app/api
# Install Poetry
ENV POETRY_VERSION=1.8.3
# if you located in China, you can use aliyun mirror to speed up
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
# Configure Poetry
@@ -16,6 +20,9 @@ ENV POETRY_REQUESTS_TIMEOUT=15
FROM base AS packages
# if you located in China, you can use aliyun mirror to speed up
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
@@ -43,10 +50,12 @@ WORKDIR /app/api
RUN apt-get update \
&& apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# if you located in China, you can use aliyun mirror to speed up
# && echo "deb http://mirrors.aliyun.com/debian testing main" > /etc/apt/sources.list \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \
# For Security
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-2 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-2 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*
@@ -56,7 +65,7 @@ COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN python -c "import nltk; nltk.download('punkt')"
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
# Copy source code
COPY . /app/api/

View File

@@ -559,8 +559,9 @@ def add_qdrant_doc_id_index(field: str):
@click.command("create-tenant", help="Create account and tenant.")
@click.option("--email", prompt=True, help="The email address of the tenant account.")
@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
@click.option("--language", prompt=True, help="Account language, default: en-US.")
def create_tenant(email: str, language: Optional[str] = None):
def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
"""
Create tenant account
"""
@@ -580,13 +581,15 @@ def create_tenant(email: str, language: Optional[str] = None):
if language not in languages:
language = "en-US"
name = name.strip()
# generate random password
new_password = secrets.token_urlsafe(16)
# register account
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
TenantService.create_owner_tenant_if_not_exist(account)
TenantService.create_owner_tenant_if_not_exist(account, name)
click.echo(
click.style(

View File

@@ -1,3 +1,3 @@
from .app_config import DifyConfig
dify_config = DifyConfig()
dify_config = DifyConfig()

View File

@@ -1,4 +1,3 @@
from pydantic import Field, computed_field
from pydantic_settings import SettingsConfigDict
from configs.deploy import DeploymentConfig
@@ -24,44 +23,16 @@ class DifyConfig(
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
):
DEBUG: bool = Field(default=False, description='whether to enable debug mode.')
model_config = SettingsConfigDict(
# read from dotenv format config file
env_file='.env',
env_file_encoding='utf-8',
env_file=".env",
env_file_encoding="utf-8",
frozen=True,
# ignore extra attributes
extra='ignore',
extra="ignore",
)
CODE_MAX_NUMBER: int = 9223372036854775807
CODE_MIN_NUMBER: int = -9223372036854775808
CODE_MAX_DEPTH: int = 5
CODE_MAX_PRECISION: int = 20
CODE_MAX_STRING_LENGTH: int = 80000
CODE_MAX_STRING_ARRAY_LENGTH: int = 30
CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30
CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300
HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10
@computed_field
def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str:
return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024
@computed_field
def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str:
return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
SSRF_PROXY_HTTP_URL: str | None = None
SSRF_PROXY_HTTPS_URL: str | None = None
MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')
MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.')
# Before adding any config,
# please consider to arrange it in the proper config group of existed or added
# for better readability and maintainability.
# Thanks for your concentration and consideration.

View File

@@ -6,22 +6,28 @@ class DeploymentConfig(BaseSettings):
"""
Deployment configs
"""
APPLICATION_NAME: str = Field(
description='application name',
default='langgenius/dify',
description="application name",
default="langgenius/dify",
)
DEBUG: bool = Field(
description="whether to enable debug mode.",
default=False,
)
TESTING: bool = Field(
description='',
description="",
default=False,
)
EDITION: str = Field(
description='deployment edition',
default='SELF_HOSTED',
description="deployment edition",
default="SELF_HOSTED",
)
DEPLOY_ENV: str = Field(
description='deployment environment, default to PRODUCTION.',
default='PRODUCTION',
description="deployment environment, default to PRODUCTION.",
default="PRODUCTION",
)

View File

@@ -7,13 +7,14 @@ class EnterpriseFeatureConfig(BaseSettings):
Enterprise feature configs.
**Before using, please contact business@dify.ai by email to inquire about licensing matters.**
"""
ENTERPRISE_ENABLED: bool = Field(
description='whether to enable enterprise features.'
'Before using, please contact business@dify.ai by email to inquire about licensing matters.',
description="whether to enable enterprise features."
"Before using, please contact business@dify.ai by email to inquire about licensing matters.",
default=False,
)
CAN_REPLACE_LOGO: bool = Field(
description='whether to allow replacing enterprise logo.',
description="whether to allow replacing enterprise logo.",
default=False,
)

View File

@@ -8,27 +8,28 @@ class NotionConfig(BaseSettings):
"""
Notion integration configs
"""
NOTION_CLIENT_ID: Optional[str] = Field(
description='Notion client ID',
description="Notion client ID",
default=None,
)
NOTION_CLIENT_SECRET: Optional[str] = Field(
description='Notion client secret key',
description="Notion client secret key",
default=None,
)
NOTION_INTEGRATION_TYPE: Optional[str] = Field(
description='Notion integration type, default to None, available values: internal.',
description="Notion integration type, default to None, available values: internal.",
default=None,
)
NOTION_INTERNAL_SECRET: Optional[str] = Field(
description='Notion internal secret key',
description="Notion internal secret key",
default=None,
)
NOTION_INTEGRATION_TOKEN: Optional[str] = Field(
description='Notion integration token',
description="Notion integration token",
default=None,
)

View File

@@ -8,17 +8,18 @@ class SentryConfig(BaseSettings):
"""
Sentry configs
"""
SENTRY_DSN: Optional[str] = Field(
description='Sentry DSN',
description="Sentry DSN",
default=None,
)
SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field(
description='Sentry trace sample rate',
description="Sentry trace sample rate",
default=1.0,
)
SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field(
description='Sentry profiles sample rate',
description="Sentry profiles sample rate",
default=1.0,
)

View File

@@ -1,6 +1,6 @@
from typing import Optional
from typing import Annotated, Optional
from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig
@@ -10,16 +10,17 @@ class SecurityConfig(BaseSettings):
"""
Secret Key configs
"""
SECRET_KEY: Optional[str] = Field(
description='Your App secret key will be used for securely signing the session cookie'
'Make sure you are changing this key for your deployment with a strong key.'
'You can generate a strong key using `openssl rand -base64 42`.'
'Alternatively you can set it with `SECRET_KEY` environment variable.',
description="Your App secret key will be used for securely signing the session cookie"
"Make sure you are changing this key for your deployment with a strong key."
"You can generate a strong key using `openssl rand -base64 42`."
"Alternatively you can set it with `SECRET_KEY` environment variable.",
default=None,
)
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
description='Expiry time in hours for reset token',
description="Expiry time in hours for reset token",
default=24,
)
@@ -28,12 +29,13 @@ class AppExecutionConfig(BaseSettings):
"""
App Execution configs
"""
APP_MAX_EXECUTION_TIME: PositiveInt = Field(
description='execution timeout in seconds for app execution',
description="execution timeout in seconds for app execution",
default=1200,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description='max active request per app, 0 means unlimited',
description="max active request per app, 0 means unlimited",
default=0,
)
@@ -42,14 +44,70 @@ class CodeExecutionSandboxConfig(BaseSettings):
"""
Code Execution Sandbox configs
"""
CODE_EXECUTION_ENDPOINT: str = Field(
description='endpoint URL of code execution servcie',
default='http://sandbox:8194',
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
description="endpoint URL of code execution servcie",
default="http://sandbox:8194",
)
CODE_EXECUTION_API_KEY: str = Field(
description='API key for code execution service',
default='dify-sandbox',
description="API key for code execution service",
default="dify-sandbox",
)
CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field(
description="connect timeout in seconds for code execution request",
default=10.0,
)
CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field(
description="read timeout in seconds for code execution request",
default=60.0,
)
CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field(
description="write timeout in seconds for code execution request",
default=10.0,
)
CODE_MAX_NUMBER: PositiveInt = Field(
description="max depth for code execution",
default=9223372036854775807,
)
CODE_MIN_NUMBER: NegativeInt = Field(
description="",
default=-9223372036854775807,
)
CODE_MAX_DEPTH: PositiveInt = Field(
description="max depth for code execution",
default=5,
)
CODE_MAX_PRECISION: PositiveInt = Field(
description="max precision digits for float type in code execution",
default=20,
)
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
description="max string length for code execution",
default=80000,
)
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
description="",
default=30,
)
CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field(
description="",
default=30,
)
CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field(
description="",
default=1000,
)
@@ -57,28 +115,27 @@ class EndpointConfig(BaseSettings):
"""
Module URL configs
"""
CONSOLE_API_URL: str = Field(
description='The backend URL prefix of the console API.'
'used to concatenate the login authorization callback or notion integration callback.',
default='',
description="The backend URL prefix of the console API."
"used to concatenate the login authorization callback or notion integration callback.",
default="",
)
CONSOLE_WEB_URL: str = Field(
description='The front-end URL prefix of the console web.'
'used to concatenate some front-end addresses and for CORS configuration use.',
default='',
description="The front-end URL prefix of the console web."
"used to concatenate some front-end addresses and for CORS configuration use.",
default="",
)
SERVICE_API_URL: str = Field(
description='Service API Url prefix.'
'used to display Service API Base Url to the front-end.',
default='',
description="Service API Url prefix." "used to display Service API Base Url to the front-end.",
default="",
)
APP_WEB_URL: str = Field(
description='WebApp Url prefix.'
'used to display WebAPP API Base Url to the front-end.',
default='',
description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.",
default="",
)
@@ -86,17 +143,18 @@ class FileAccessConfig(BaseSettings):
"""
File Access configs
"""
FILES_URL: str = Field(
description='File preview or download Url prefix.'
' used to display File preview or download Url to the front-end or as Multi-model inputs;'
'Url is signed and has expiration time.',
validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'),
description="File preview or download Url prefix."
" used to display File preview or download Url to the front-end or as Multi-model inputs;"
"Url is signed and has expiration time.",
validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"),
alias_priority=1,
default='',
default="",
)
FILES_ACCESS_TIMEOUT: int = Field(
description='timeout in seconds for file accessing',
description="timeout in seconds for file accessing",
default=300,
)
@@ -105,23 +163,24 @@ class FileUploadConfig(BaseSettings):
"""
File Uploading configs
"""
UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description='size limit in Megabytes for uploading files',
description="size limit in Megabytes for uploading files",
default=15,
)
UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field(
description='batch size limit for uploading files',
description="batch size limit for uploading files",
default=5,
)
UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description='image file size limit in Megabytes for uploading files',
description="image file size limit in Megabytes for uploading files",
default=10,
)
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
description='', # todo: to be clarified
description="", # todo: to be clarified
default=20,
)
@@ -130,45 +189,79 @@ class HttpConfig(BaseSettings):
"""
HTTP configs
"""
API_COMPRESSION_ENABLED: bool = Field(
description='whether to enable HTTP response compression of gzip',
description="whether to enable HTTP response compression of gzip",
default=False,
)
inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
description='',
validation_alias=AliasChoices('CONSOLE_CORS_ALLOW_ORIGINS', 'CONSOLE_WEB_URL'),
default='',
description="",
validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"),
default="",
)
@computed_field
@property
def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',')
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",")
inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field(
description='',
validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'),
default='*',
description="",
validation_alias=AliasChoices("WEB_API_CORS_ALLOW_ORIGINS"),
default="*",
)
@computed_field
@property
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',')
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
] = 10
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
] = 60
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request")
] = 20
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
description="",
default=10 * 1024 * 1024,
)
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field(
description="",
default=1 * 1024 * 1024,
)
SSRF_PROXY_HTTP_URL: Optional[str] = Field(
description="HTTP URL for SSRF proxy",
default=None,
)
SSRF_PROXY_HTTPS_URL: Optional[str] = Field(
description="HTTPS URL for SSRF proxy",
default=None,
)
class InnerAPIConfig(BaseSettings):
"""
Inner API configs
"""
INNER_API: bool = Field(
description='whether to enable the inner API',
description="whether to enable the inner API",
default=False,
)
INNER_API_KEY: Optional[str] = Field(
description='The inner API key is used to authenticate the inner API',
description="The inner API key is used to authenticate the inner API",
default=None,
)
@@ -179,28 +272,27 @@ class LoggingConfig(BaseSettings):
"""
LOG_LEVEL: str = Field(
description='Log output level, default to INFO.'
'It is recommended to set it to ERROR for production.',
default='INFO',
description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.",
default="INFO",
)
LOG_FILE: Optional[str] = Field(
description='logging output file path',
description="logging output file path",
default=None,
)
LOG_FORMAT: str = Field(
description='log format',
default='%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s',
description="log format",
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
)
LOG_DATEFORMAT: Optional[str] = Field(
description='log date format',
description="log date format",
default=None,
)
LOG_TZ: Optional[str] = Field(
description='specify log timezone, eg: America/New_York',
description="specify log timezone, eg: America/New_York",
default=None,
)
@@ -209,8 +301,9 @@ class ModelLoadBalanceConfig(BaseSettings):
"""
Model load balance configs
"""
MODEL_LB_ENABLED: bool = Field(
description='whether to enable model load balancing',
description="whether to enable model load balancing",
default=False,
)
@@ -219,8 +312,9 @@ class BillingConfig(BaseSettings):
"""
Platform Billing Configurations
"""
BILLING_ENABLED: bool = Field(
description='whether to enable billing',
description="whether to enable billing",
default=False,
)
@@ -229,9 +323,10 @@ class UpdateConfig(BaseSettings):
"""
Update configs
"""
CHECK_UPDATE_URL: str = Field(
description='url for checking updates',
default='https://updates.dify.ai',
description="url for checking updates",
default="https://updates.dify.ai",
)
@@ -241,47 +336,53 @@ class WorkflowConfig(BaseSettings):
"""
WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field(
description='max execution steps in single workflow execution',
description="max execution steps in single workflow execution",
default=500,
)
WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field(
description='max execution time in seconds in single workflow execution',
description="max execution time in seconds in single workflow execution",
default=1200,
)
WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field(
description='max depth of calling in single workflow execution',
description="max depth of calling in single workflow execution",
default=5,
)
MAX_VARIABLE_SIZE: PositiveInt = Field(
description="The maximum size in bytes of a variable. default to 5KB.",
default=5 * 1024,
)
class OAuthConfig(BaseSettings):
"""
oauth configs
"""
OAUTH_REDIRECT_PATH: str = Field(
description='redirect path for OAuth',
default='/console/api/oauth/authorize',
description="redirect path for OAuth",
default="/console/api/oauth/authorize",
)
GITHUB_CLIENT_ID: Optional[str] = Field(
description='GitHub client id for OAuth',
description="GitHub client id for OAuth",
default=None,
)
GITHUB_CLIENT_SECRET: Optional[str] = Field(
description='GitHub client secret key for OAuth',
description="GitHub client secret key for OAuth",
default=None,
)
GOOGLE_CLIENT_ID: Optional[str] = Field(
description='Google client id for OAuth',
description="Google client id for OAuth",
default=None,
)
GOOGLE_CLIENT_SECRET: Optional[str] = Field(
description='Google client secret key for OAuth',
description="Google client secret key for OAuth",
default=None,
)
@@ -291,9 +392,8 @@ class ModerationConfig(BaseSettings):
Moderation in app configs.
"""
# todo: to be clarified in usage and unit
OUTPUT_MODERATION_BUFFER_SIZE: PositiveInt = Field(
description='buffer size for moderation',
MODERATION_BUFFER_SIZE: PositiveInt = Field(
description="buffer size for moderation",
default=300,
)
@@ -304,7 +404,7 @@ class ToolConfig(BaseSettings):
"""
TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field(
description='max age in seconds for tool icon caching',
description="max age in seconds for tool icon caching",
default=3600,
)
@@ -315,52 +415,52 @@ class MailConfig(BaseSettings):
"""
MAIL_TYPE: Optional[str] = Field(
description='Mail provider type name, default to None, availabile values are `smtp` and `resend`.',
description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
default=None,
)
MAIL_DEFAULT_SEND_FROM: Optional[str] = Field(
description='default email address for sending from ',
description="default email address for sending from ",
default=None,
)
RESEND_API_KEY: Optional[str] = Field(
description='API key for Resend',
description="API key for Resend",
default=None,
)
RESEND_API_URL: Optional[str] = Field(
description='API URL for Resend',
description="API URL for Resend",
default=None,
)
SMTP_SERVER: Optional[str] = Field(
description='smtp server host',
description="smtp server host",
default=None,
)
SMTP_PORT: Optional[int] = Field(
description='smtp server port',
description="smtp server port",
default=465,
)
SMTP_USERNAME: Optional[str] = Field(
description='smtp server username',
description="smtp server username",
default=None,
)
SMTP_PASSWORD: Optional[str] = Field(
description='smtp server password',
description="smtp server password",
default=None,
)
SMTP_USE_TLS: bool = Field(
description='whether to use TLS connection to smtp server',
description="whether to use TLS connection to smtp server",
default=False,
)
SMTP_OPPORTUNISTIC_TLS: bool = Field(
description='whether to use opportunistic TLS connection to smtp server',
description="whether to use opportunistic TLS connection to smtp server",
default=False,
)
@@ -371,22 +471,22 @@ class RagEtlConfig(BaseSettings):
"""
ETL_TYPE: str = Field(
description='RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ',
default='dify',
description="RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ",
default="dify",
)
KEYWORD_DATA_SOURCE_TYPE: str = Field(
description='source type for keyword data, default to `database`, available values are `database` .',
default='database',
description="source type for keyword data, default to `database`, available values are `database` .",
default="database",
)
UNSTRUCTURED_API_URL: Optional[str] = Field(
description='API URL for Unstructured',
description="API URL for Unstructured",
default=None,
)
UNSTRUCTURED_API_KEY: Optional[str] = Field(
description='API key for Unstructured',
description="API key for Unstructured",
default=None,
)
@@ -397,12 +497,12 @@ class DataSetConfig(BaseSettings):
"""
CLEAN_DAY_SETTING: PositiveInt = Field(
description='interval in days for cleaning up dataset',
description="interval in days for cleaning up dataset",
default=30,
)
DATASET_OPERATOR_ENABLED: bool = Field(
description='whether to enable dataset operator',
description="whether to enable dataset operator",
default=False,
)
@@ -413,7 +513,7 @@ class WorkspaceConfig(BaseSettings):
"""
INVITE_EXPIRY_HOURS: PositiveInt = Field(
description='workspaces invitation expiration in hours',
description="workspaces invitation expiration in hours",
default=72,
)
@@ -424,80 +524,79 @@ class IndexingConfig(BaseSettings):
"""
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field(
description='max segmentation token length for indexing',
description="max segmentation token length for indexing",
default=1000,
)
class ImageFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
description='multi model send image format, support base64, url, default is base64',
default='base64',
description="multi model send image format, support base64, url, default is base64",
default="base64",
)
class CeleryBeatConfig(BaseSettings):
CELERY_BEAT_SCHEDULER_TIME: int = Field(
description='the time of the celery scheduler, default to 1 day',
description="the time of the celery scheduler, default to 1 day",
default=1,
)
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description='The heads of model providers',
default='',
description="The heads of model providers",
default="",
)
POSITION_PROVIDER_INCLUDES: str = Field(
description='The included model providers',
default='',
description="The included model providers",
default="",
)
POSITION_PROVIDER_EXCLUDES: str = Field(
description='The excluded model providers',
default='',
description="The excluded model providers",
default="",
)
POSITION_TOOL_PINS: str = Field(
description='The heads of tools',
default='',
description="The heads of tools",
default="",
)
POSITION_TOOL_INCLUDES: str = Field(
description='The included tools',
default='',
description="The included tools",
default="",
)
POSITION_TOOL_EXCLUDES: str = Field(
description='The excluded tools',
default='',
description="The excluded tools",
default="",
)
@computed_field
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""]
@computed_field
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""}
@computed_field
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""}
@computed_field
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""]
@computed_field
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""}
@computed_field
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
class FeatureConfig(
@@ -525,7 +624,6 @@ class FeatureConfig(
WorkflowConfig,
WorkspaceConfig,
PositionConfig,
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,

View File

@@ -10,62 +10,62 @@ class HostedOpenAiConfig(BaseSettings):
"""
HOSTED_OPENAI_API_KEY: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_OPENAI_API_BASE: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_OPENAI_TRIAL_ENABLED: bool = Field(
description='',
description="",
default=False,
)
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
description='',
default='gpt-3.5-turbo,'
'gpt-3.5-turbo-1106,'
'gpt-3.5-turbo-instruct,'
'gpt-3.5-turbo-16k,'
'gpt-3.5-turbo-16k-0613,'
'gpt-3.5-turbo-0613,'
'gpt-3.5-turbo-0125,'
'text-davinci-003',
description="",
default="gpt-3.5-turbo,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-instruct,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"text-davinci-003",
)
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description='',
description="",
default=200,
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
description='',
description="",
default=False,
)
HOSTED_OPENAI_PAID_MODELS: str = Field(
description='',
default='gpt-4,'
'gpt-4-turbo-preview,'
'gpt-4-turbo-2024-04-09,'
'gpt-4-1106-preview,'
'gpt-4-0125-preview,'
'gpt-3.5-turbo,'
'gpt-3.5-turbo-16k,'
'gpt-3.5-turbo-16k-0613,'
'gpt-3.5-turbo-1106,'
'gpt-3.5-turbo-0613,'
'gpt-3.5-turbo-0125,'
'gpt-3.5-turbo-instruct,'
'text-davinci-003',
description="",
default="gpt-4,"
"gpt-4-turbo-preview,"
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
"gpt-3.5-turbo,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"gpt-3.5-turbo-instruct,"
"text-davinci-003",
)
@@ -75,22 +75,22 @@ class HostedAzureOpenAiConfig(BaseSettings):
"""
HOSTED_AZURE_OPENAI_ENABLED: bool = Field(
description='',
description="",
default=False,
)
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description='',
description="",
default=200,
)
@@ -101,27 +101,27 @@ class HostedAnthropicConfig(BaseSettings):
"""
HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field(
description='',
description="",
default=None,
)
HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field(
description='',
description="",
default=False,
)
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description='',
description="",
default=600000,
)
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description='',
description="",
default=False,
)
@@ -132,7 +132,7 @@ class HostedMinmaxConfig(BaseSettings):
"""
HOSTED_MINIMAX_ENABLED: bool = Field(
description='',
description="",
default=False,
)
@@ -143,7 +143,7 @@ class HostedSparkConfig(BaseSettings):
"""
HOSTED_SPARK_ENABLED: bool = Field(
description='',
description="",
default=False,
)
@@ -154,7 +154,7 @@ class HostedZhipuAIConfig(BaseSettings):
"""
HOSTED_ZHIPUAI_ENABLED: bool = Field(
description='',
description="",
default=False,
)
@@ -165,13 +165,13 @@ class HostedModerationConfig(BaseSettings):
"""
HOSTED_MODERATION_ENABLED: bool = Field(
description='',
description="",
default=False,
)
HOSTED_MODERATION_PROVIDERS: str = Field(
description='',
default='',
description="",
default="",
)
@@ -181,15 +181,15 @@ class HostedFetchAppTemplateConfig(BaseSettings):
"""
HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
description='the mode for fetching app templates,'
' default to remote,'
' available values: remote, db, builtin',
default='remote',
description="the mode for fetching app templates,"
" default to remote,"
" available values: remote, db, builtin",
default="remote",
)
HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field(
description='the domain for fetching remote app templates',
default='https://tmpl.dify.ai',
description="the domain for fetching remote app templates",
default="https://tmpl.dify.ai",
)
@@ -202,7 +202,6 @@ class HostedServiceConfig(
HostedOpenAiConfig,
HostedSparkConfig,
HostedZhipuAIConfig,
# moderation
HostedModerationConfig,
):

View File

@@ -13,6 +13,7 @@ from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
@@ -28,108 +29,108 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseSettings):
STORAGE_TYPE: str = Field(
description='storage type,'
' default to `local`,'
' available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.',
default='local',
description="storage type,"
" default to `local`,"
" available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.",
default="local",
)
STORAGE_LOCAL_PATH: str = Field(
description='local storage path',
default='storage',
description="local storage path",
default="storage",
)
class VectorStoreConfig(BaseSettings):
VECTOR_STORE: Optional[str] = Field(
description='vector store type',
description="vector store type",
default=None,
)
class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field(
description='keyword store type',
default='jieba',
description="keyword store type",
default="jieba",
)
class DatabaseConfig:
DB_HOST: str = Field(
description='db host',
default='localhost',
description="db host",
default="localhost",
)
DB_PORT: PositiveInt = Field(
description='db port',
description="db port",
default=5432,
)
DB_USERNAME: str = Field(
description='db username',
default='postgres',
description="db username",
default="postgres",
)
DB_PASSWORD: str = Field(
description='db password',
default='',
description="db password",
default="",
)
DB_DATABASE: str = Field(
description='db database',
default='dify',
description="db database",
default="dify",
)
DB_CHARSET: str = Field(
description='db charset',
default='',
description="db charset",
default="",
)
DB_EXTRAS: str = Field(
description='db extras options. Example: keepalives_idle=60&keepalives=1',
default='',
description="db extras options. Example: keepalives_idle=60&keepalives=1",
default="",
)
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
description='db uri scheme',
default='postgresql',
description="db uri scheme",
default="postgresql",
)
@computed_field
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = (
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}"
if self.DB_CHARSET
else self.DB_EXTRAS
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS
).strip("&")
db_extras = f"?{db_extras}" if db_extras else ""
return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
f"{db_extras}")
return (
f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
f"{db_extras}"
)
SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field(
description='pool size of SqlAlchemy',
description="pool size of SqlAlchemy",
default=30,
)
SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field(
description='max overflows for SqlAlchemy',
description="max overflows for SqlAlchemy",
default=10,
)
SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field(
description='SqlAlchemy pool recycle',
description="SqlAlchemy pool recycle",
default=3600,
)
SQLALCHEMY_POOL_PRE_PING: bool = Field(
description='whether to enable pool pre-ping in SqlAlchemy',
description="whether to enable pool pre-ping in SqlAlchemy",
default=False,
)
SQLALCHEMY_ECHO: bool | str = Field(
description='whether to enable SqlAlchemy echo',
description="whether to enable SqlAlchemy echo",
default=False,
)
@@ -137,35 +138,38 @@ class DatabaseConfig:
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
return {
'pool_size': self.SQLALCHEMY_POOL_SIZE,
'max_overflow': self.SQLALCHEMY_MAX_OVERFLOW,
'pool_recycle': self.SQLALCHEMY_POOL_RECYCLE,
'pool_pre_ping': self.SQLALCHEMY_POOL_PRE_PING,
'connect_args': {'options': '-c timezone=UTC'},
"pool_size": self.SQLALCHEMY_POOL_SIZE,
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": {"options": "-c timezone=UTC"},
}
class CeleryConfig(DatabaseConfig):
CELERY_BACKEND: str = Field(
description='Celery backend, available values are `database`, `redis`',
default='database',
description="Celery backend, available values are `database`, `redis`",
default="database",
)
CELERY_BROKER_URL: Optional[str] = Field(
description='CELERY_BROKER_URL',
description="CELERY_BROKER_URL",
default=None,
)
@computed_field
@property
def CELERY_RESULT_BACKEND(self) -> str | None:
return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
return (
"db+{}".format(self.SQLALCHEMY_DATABASE_URI)
if self.CELERY_BACKEND == "database"
else self.CELERY_BROKER_URL
)
@computed_field
@property
def BROKER_USE_SSL(self) -> bool:
return self.CELERY_BROKER_URL.startswith('rediss://') if self.CELERY_BROKER_URL else False
return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False
class MiddlewareConfig(
@@ -174,7 +178,6 @@ class MiddlewareConfig(
DatabaseConfig,
KeywordStoreConfig,
RedisConfig,
# configs of storage and storage providers
StorageConfig,
AliyunOSSStorageConfig,
@@ -183,7 +186,6 @@ class MiddlewareConfig(
TencentCloudCOSStorageConfig,
S3StorageConfig,
OCIStorageConfig,
# configs of vdb and vdb providers
VectorStoreConfig,
AnalyticdbConfig,
@@ -199,5 +201,6 @@ class MiddlewareConfig(
TencentVectorDBConfig,
TiDBVectorConfig,
WeaviateConfig,
ElasticsearchConfig,
):
pass

View File

@@ -8,32 +8,33 @@ class RedisConfig(BaseSettings):
"""
Redis configs
"""
REDIS_HOST: str = Field(
description='Redis host',
default='localhost',
description="Redis host",
default="localhost",
)
REDIS_PORT: PositiveInt = Field(
description='Redis port',
description="Redis port",
default=6379,
)
REDIS_USERNAME: Optional[str] = Field(
description='Redis username',
description="Redis username",
default=None,
)
REDIS_PASSWORD: Optional[str] = Field(
description='Redis password',
description="Redis password",
default=None,
)
REDIS_DB: NonNegativeInt = Field(
description='Redis database id, default to 0',
description="Redis database id, default to 0",
default=0,
)
REDIS_USE_SSL: bool = Field(
description='whether to use SSL for Redis connection',
description="whether to use SSL for Redis connection",
default=False,
)

View File

@@ -10,31 +10,36 @@ class AliyunOSSStorageConfig(BaseSettings):
"""
ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field(
description='Aliyun OSS bucket name',
description="Aliyun OSS bucket name",
default=None,
)
ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field(
description='Aliyun OSS access key',
description="Aliyun OSS access key",
default=None,
)
ALIYUN_OSS_SECRET_KEY: Optional[str] = Field(
description='Aliyun OSS secret key',
description="Aliyun OSS secret key",
default=None,
)
ALIYUN_OSS_ENDPOINT: Optional[str] = Field(
description='Aliyun OSS endpoint URL',
description="Aliyun OSS endpoint URL",
default=None,
)
ALIYUN_OSS_REGION: Optional[str] = Field(
description='Aliyun OSS region',
description="Aliyun OSS region",
default=None,
)
ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field(
description='Aliyun OSS authentication version',
description="Aliyun OSS authentication version",
default=None,
)
ALIYUN_OSS_PATH: Optional[str] = Field(
description="Aliyun OSS path",
default=None,
)

View File

@@ -10,36 +10,36 @@ class S3StorageConfig(BaseSettings):
"""
S3_ENDPOINT: Optional[str] = Field(
description='S3 storage endpoint',
description="S3 storage endpoint",
default=None,
)
S3_REGION: Optional[str] = Field(
description='S3 storage region',
description="S3 storage region",
default=None,
)
S3_BUCKET_NAME: Optional[str] = Field(
description='S3 storage bucket name',
description="S3 storage bucket name",
default=None,
)
S3_ACCESS_KEY: Optional[str] = Field(
description='S3 storage access key',
description="S3 storage access key",
default=None,
)
S3_SECRET_KEY: Optional[str] = Field(
description='S3 storage secret key',
description="S3 storage secret key",
default=None,
)
S3_ADDRESS_STYLE: str = Field(
description='S3 storage address style',
default='auto',
description="S3 storage address style",
default="auto",
)
S3_USE_AWS_MANAGED_IAM: bool = Field(
description='whether to use aws managed IAM for S3',
description="whether to use aws managed IAM for S3",
default=False,
)

View File

@@ -10,21 +10,21 @@ class AzureBlobStorageConfig(BaseSettings):
"""
AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field(
description='Azure Blob account name',
description="Azure Blob account name",
default=None,
)
AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field(
description='Azure Blob account key',
description="Azure Blob account key",
default=None,
)
AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field(
description='Azure Blob container name',
description="Azure Blob container name",
default=None,
)
AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field(
description='Azure Blob account URL',
description="Azure Blob account URL",
default=None,
)

View File

@@ -10,11 +10,11 @@ class GoogleCloudStorageConfig(BaseSettings):
"""
GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field(
description='Google Cloud storage bucket name',
description="Google Cloud storage bucket name",
default=None,
)
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field(
description='Google Cloud storage service account json base64',
description="Google Cloud storage service account json base64",
default=None,
)

View File

@@ -10,27 +10,26 @@ class OCIStorageConfig(BaseSettings):
"""
OCI_ENDPOINT: Optional[str] = Field(
description='OCI storage endpoint',
description="OCI storage endpoint",
default=None,
)
OCI_REGION: Optional[str] = Field(
description='OCI storage region',
description="OCI storage region",
default=None,
)
OCI_BUCKET_NAME: Optional[str] = Field(
description='OCI storage bucket name',
description="OCI storage bucket name",
default=None,
)
OCI_ACCESS_KEY: Optional[str] = Field(
description='OCI storage access key',
description="OCI storage access key",
default=None,
)
OCI_SECRET_KEY: Optional[str] = Field(
description='OCI storage secret key',
description="OCI storage secret key",
default=None,
)

View File

@@ -10,26 +10,26 @@ class TencentCloudCOSStorageConfig(BaseSettings):
"""
TENCENT_COS_BUCKET_NAME: Optional[str] = Field(
description='Tencent Cloud COS bucket name',
description="Tencent Cloud COS bucket name",
default=None,
)
TENCENT_COS_REGION: Optional[str] = Field(
description='Tencent Cloud COS region',
description="Tencent Cloud COS region",
default=None,
)
TENCENT_COS_SECRET_ID: Optional[str] = Field(
description='Tencent Cloud COS secret id',
description="Tencent Cloud COS secret id",
default=None,
)
TENCENT_COS_SECRET_KEY: Optional[str] = Field(
description='Tencent Cloud COS secret key',
description="Tencent Cloud COS secret key",
default=None,
)
TENCENT_COS_SCHEME: Optional[str] = Field(
description='Tencent Cloud COS scheme',
description="Tencent Cloud COS scheme",
default=None,
)

View File

@@ -10,35 +10,28 @@ class AnalyticdbConfig(BaseModel):
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
"""
ANALYTICDB_KEY_ID : Optional[str] = Field(
default=None,
description="The Access Key ID provided by Alibaba Cloud for authentication."
ANALYTICDB_KEY_ID: Optional[str] = Field(
default=None, description="The Access Key ID provided by Alibaba Cloud for authentication."
)
ANALYTICDB_KEY_SECRET : Optional[str] = Field(
default=None,
description="The Secret Access Key corresponding to the Access Key ID for secure access."
ANALYTICDB_KEY_SECRET: Optional[str] = Field(
default=None, description="The Secret Access Key corresponding to the Access Key ID for secure access."
)
ANALYTICDB_REGION_ID : Optional[str] = Field(
default=None,
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
ANALYTICDB_REGION_ID: Optional[str] = Field(
default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
)
ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
ANALYTICDB_INSTANCE_ID: Optional[str] = Field(
default=None,
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')..",
)
ANALYTICDB_ACCOUNT : Optional[str] = Field(
default=None,
description="The account name used to log in to the AnalyticDB instance."
ANALYTICDB_ACCOUNT: Optional[str] = Field(
default=None, description="The account name used to log in to the AnalyticDB instance."
)
ANALYTICDB_PASSWORD : Optional[str] = Field(
default=None,
description="The password associated with the AnalyticDB account for authentication."
ANALYTICDB_PASSWORD: Optional[str] = Field(
default=None, description="The password associated with the AnalyticDB account for authentication."
)
ANALYTICDB_NAMESPACE : Optional[str] = Field(
default=None,
description="The namespace within AnalyticDB for schema isolation."
ANALYTICDB_NAMESPACE: Optional[str] = Field(
default=None, description="The namespace within AnalyticDB for schema isolation."
)
ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
default=None,
description="The password for accessing the specified namespace within the AnalyticDB instance."
ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field(
default=None, description="The password for accessing the specified namespace within the AnalyticDB instance."
)

View File

@@ -10,31 +10,31 @@ class ChromaConfig(BaseSettings):
"""
CHROMA_HOST: Optional[str] = Field(
description='Chroma host',
description="Chroma host",
default=None,
)
CHROMA_PORT: PositiveInt = Field(
description='Chroma port',
description="Chroma port",
default=8000,
)
CHROMA_TENANT: Optional[str] = Field(
description='Chroma database',
description="Chroma database",
default=None,
)
CHROMA_DATABASE: Optional[str] = Field(
description='Chroma database',
description="Chroma database",
default=None,
)
CHROMA_AUTH_PROVIDER: Optional[str] = Field(
description='Chroma authentication provider',
description="Chroma authentication provider",
default=None,
)
CHROMA_AUTH_CREDENTIALS: Optional[str] = Field(
description='Chroma authentication credentials',
description="Chroma authentication credentials",
default=None,
)

View File

@@ -0,0 +1,30 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class ElasticsearchConfig(BaseSettings):
"""
Elasticsearch configs
"""
ELASTICSEARCH_HOST: Optional[str] = Field(
description="Elasticsearch host",
default="127.0.0.1",
)
ELASTICSEARCH_PORT: PositiveInt = Field(
description="Elasticsearch port",
default=9200,
)
ELASTICSEARCH_USERNAME: Optional[str] = Field(
description="Elasticsearch username",
default="elastic",
)
ELASTICSEARCH_PASSWORD: Optional[str] = Field(
description="Elasticsearch password",
default="elastic",
)

View File

@@ -10,31 +10,31 @@ class MilvusConfig(BaseSettings):
"""
MILVUS_HOST: Optional[str] = Field(
description='Milvus host',
description="Milvus host",
default=None,
)
MILVUS_PORT: PositiveInt = Field(
description='Milvus RestFul API port',
description="Milvus RestFul API port",
default=9091,
)
MILVUS_USER: Optional[str] = Field(
description='Milvus user',
description="Milvus user",
default=None,
)
MILVUS_PASSWORD: Optional[str] = Field(
description='Milvus password',
description="Milvus password",
default=None,
)
MILVUS_SECURE: bool = Field(
description='whether to use SSL connection for Milvus',
description="whether to use SSL connection for Milvus",
default=False,
)
MILVUS_DATABASE: str = Field(
description='Milvus database, default to `default`',
default='default',
description="Milvus database, default to `default`",
default="default",
)

View File

@@ -1,4 +1,3 @@
from pydantic import BaseModel, Field, PositiveInt
@@ -8,31 +7,31 @@ class MyScaleConfig(BaseModel):
"""
MYSCALE_HOST: str = Field(
description='MyScale host',
default='localhost',
description="MyScale host",
default="localhost",
)
MYSCALE_PORT: PositiveInt = Field(
description='MyScale port',
description="MyScale port",
default=8123,
)
MYSCALE_USER: str = Field(
description='MyScale user',
default='default',
description="MyScale user",
default="default",
)
MYSCALE_PASSWORD: str = Field(
description='MyScale password',
default='',
description="MyScale password",
default="",
)
MYSCALE_DATABASE: str = Field(
description='MyScale database name',
default='default',
description="MyScale database name",
default="default",
)
MYSCALE_FTS_PARAMS: str = Field(
description='MyScale fts index parameters',
default='',
description="MyScale fts index parameters",
default="",
)

View File

@@ -10,26 +10,26 @@ class OpenSearchConfig(BaseSettings):
"""
OPENSEARCH_HOST: Optional[str] = Field(
description='OpenSearch host',
description="OpenSearch host",
default=None,
)
OPENSEARCH_PORT: PositiveInt = Field(
description='OpenSearch port',
description="OpenSearch port",
default=9200,
)
OPENSEARCH_USER: Optional[str] = Field(
description='OpenSearch user',
description="OpenSearch user",
default=None,
)
OPENSEARCH_PASSWORD: Optional[str] = Field(
description='OpenSearch password',
description="OpenSearch password",
default=None,
)
OPENSEARCH_SECURE: bool = Field(
description='whether to use SSL connection for OpenSearch',
description="whether to use SSL connection for OpenSearch",
default=False,
)

View File

@@ -10,26 +10,26 @@ class OracleConfig(BaseSettings):
"""
ORACLE_HOST: Optional[str] = Field(
description='ORACLE host',
description="ORACLE host",
default=None,
)
ORACLE_PORT: Optional[PositiveInt] = Field(
description='ORACLE port',
description="ORACLE port",
default=1521,
)
ORACLE_USER: Optional[str] = Field(
description='ORACLE user',
description="ORACLE user",
default=None,
)
ORACLE_PASSWORD: Optional[str] = Field(
description='ORACLE password',
description="ORACLE password",
default=None,
)
ORACLE_DATABASE: Optional[str] = Field(
description='ORACLE database',
description="ORACLE database",
default=None,
)

View File

@@ -10,26 +10,26 @@ class PGVectorConfig(BaseSettings):
"""
PGVECTOR_HOST: Optional[str] = Field(
description='PGVector host',
description="PGVector host",
default=None,
)
PGVECTOR_PORT: Optional[PositiveInt] = Field(
description='PGVector port',
description="PGVector port",
default=5433,
)
PGVECTOR_USER: Optional[str] = Field(
description='PGVector user',
description="PGVector user",
default=None,
)
PGVECTOR_PASSWORD: Optional[str] = Field(
description='PGVector password',
description="PGVector password",
default=None,
)
PGVECTOR_DATABASE: Optional[str] = Field(
description='PGVector database',
description="PGVector database",
default=None,
)

View File

@@ -10,26 +10,26 @@ class PGVectoRSConfig(BaseSettings):
"""
PGVECTO_RS_HOST: Optional[str] = Field(
description='PGVectoRS host',
description="PGVectoRS host",
default=None,
)
PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
description='PGVectoRS port',
description="PGVectoRS port",
default=5431,
)
PGVECTO_RS_USER: Optional[str] = Field(
description='PGVectoRS user',
description="PGVectoRS user",
default=None,
)
PGVECTO_RS_PASSWORD: Optional[str] = Field(
description='PGVectoRS password',
description="PGVectoRS password",
default=None,
)
PGVECTO_RS_DATABASE: Optional[str] = Field(
description='PGVectoRS database',
description="PGVectoRS database",
default=None,
)

View File

@@ -10,26 +10,26 @@ class QdrantConfig(BaseSettings):
"""
QDRANT_URL: Optional[str] = Field(
description='Qdrant url',
description="Qdrant url",
default=None,
)
QDRANT_API_KEY: Optional[str] = Field(
description='Qdrant api key',
description="Qdrant api key",
default=None,
)
QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
description='Qdrant client timeout in seconds',
description="Qdrant client timeout in seconds",
default=20,
)
QDRANT_GRPC_ENABLED: bool = Field(
description='whether enable grpc support for Qdrant connection',
description="whether enable grpc support for Qdrant connection",
default=False,
)
QDRANT_GRPC_PORT: PositiveInt = Field(
description='Qdrant grpc port',
description="Qdrant grpc port",
default=6334,
)

View File

@@ -10,26 +10,26 @@ class RelytConfig(BaseSettings):
"""
RELYT_HOST: Optional[str] = Field(
description='Relyt host',
description="Relyt host",
default=None,
)
RELYT_PORT: PositiveInt = Field(
description='Relyt port',
description="Relyt port",
default=9200,
)
RELYT_USER: Optional[str] = Field(
description='Relyt user',
description="Relyt user",
default=None,
)
RELYT_PASSWORD: Optional[str] = Field(
description='Relyt password',
description="Relyt password",
default=None,
)
RELYT_DATABASE: Optional[str] = Field(
description='Relyt database',
default='default',
description="Relyt database",
default="default",
)

View File

@@ -10,41 +10,41 @@ class TencentVectorDBConfig(BaseSettings):
"""
TENCENT_VECTOR_DB_URL: Optional[str] = Field(
description='Tencent Vector URL',
description="Tencent Vector URL",
default=None,
)
TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field(
description='Tencent Vector API key',
description="Tencent Vector API key",
default=None,
)
TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field(
description='Tencent Vector timeout in seconds',
description="Tencent Vector timeout in seconds",
default=30,
)
TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
description='Tencent Vector username',
description="Tencent Vector username",
default=None,
)
TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field(
description='Tencent Vector password',
description="Tencent Vector password",
default=None,
)
TENCENT_VECTOR_DB_SHARD: PositiveInt = Field(
description='Tencent Vector sharding number',
description="Tencent Vector sharding number",
default=1,
)
TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
description='Tencent Vector replicas',
description="Tencent Vector replicas",
default=2,
)
TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
description='Tencent Vector Database',
description="Tencent Vector Database",
default=None,
)

View File

@@ -10,26 +10,26 @@ class TiDBVectorConfig(BaseSettings):
"""
TIDB_VECTOR_HOST: Optional[str] = Field(
description='TiDB Vector host',
description="TiDB Vector host",
default=None,
)
TIDB_VECTOR_PORT: Optional[PositiveInt] = Field(
description='TiDB Vector port',
description="TiDB Vector port",
default=4000,
)
TIDB_VECTOR_USER: Optional[str] = Field(
description='TiDB Vector user',
description="TiDB Vector user",
default=None,
)
TIDB_VECTOR_PASSWORD: Optional[str] = Field(
description='TiDB Vector password',
description="TiDB Vector password",
default=None,
)
TIDB_VECTOR_DATABASE: Optional[str] = Field(
description='TiDB Vector database',
description="TiDB Vector database",
default=None,
)

View File

@@ -10,21 +10,21 @@ class WeaviateConfig(BaseSettings):
"""
WEAVIATE_ENDPOINT: Optional[str] = Field(
description='Weaviate endpoint URL',
description="Weaviate endpoint URL",
default=None,
)
WEAVIATE_API_KEY: Optional[str] = Field(
description='Weaviate API key',
description="Weaviate API key",
default=None,
)
WEAVIATE_GRPC_ENABLED: bool = Field(
description='whether to enable gRPC for Weaviate connection',
description="whether to enable gRPC for Weaviate connection",
default=True,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description='Weaviate batch size',
description="Weaviate batch size",
default=100,
)

View File

@@ -8,11 +8,11 @@ class PackagingInfo(BaseSettings):
"""
CURRENT_VERSION: str = Field(
description='Dify version',
default='0.7.1',
description="Dify version",
default="0.8.0-beta1",
)
COMMIT_SHA: str = Field(
description="SHA-1 checksum of the git commit used to build the app",
default='',
default="",
)

View File

@@ -1,3 +1 @@

View File

@@ -2,7 +2,7 @@ from flask import Blueprint
from libs.external_api import ExternalApi
bp = Blueprint('console', __name__, url_prefix='/console/api')
bp = Blueprint("console", __name__, url_prefix="/console/api")
api = ExternalApi(bp)
# Import other controllers

View File

@@ -15,24 +15,24 @@ from models.model import App, InstalledApp, RecommendedApp
def admin_required(view):
@wraps(view)
def decorated(*args, **kwargs):
if not os.getenv('ADMIN_API_KEY'):
raise Unauthorized('API key is invalid.')
if not os.getenv("ADMIN_API_KEY"):
raise Unauthorized("API key is invalid.")
auth_header = request.headers.get('Authorization')
auth_header = request.headers.get("Authorization")
if auth_header is None:
raise Unauthorized('Authorization header is missing.')
raise Unauthorized("Authorization header is missing.")
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if os.getenv('ADMIN_API_KEY') != auth_token:
raise Unauthorized('API key is invalid.')
if os.getenv("ADMIN_API_KEY") != auth_token:
raise Unauthorized("API key is invalid.")
return view(*args, **kwargs)
@@ -44,37 +44,41 @@ class InsertExploreAppListApi(Resource):
@admin_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('app_id', type=str, required=True, nullable=False, location='json')
parser.add_argument('desc', type=str, location='json')
parser.add_argument('copyright', type=str, location='json')
parser.add_argument('privacy_policy', type=str, location='json')
parser.add_argument('custom_disclaimer', type=str, location='json')
parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json')
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
parser.add_argument('position', type=int, required=True, nullable=False, location='json')
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("desc", type=str, location="json")
parser.add_argument("copyright", type=str, location="json")
parser.add_argument("privacy_policy", type=str, location="json")
parser.add_argument("custom_disclaimer", type=str, location="json")
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
app = App.query.filter(App.id == args['app_id']).first()
app = App.query.filter(App.id == args["app_id"]).first()
if not app:
raise NotFound(f'App \'{args["app_id"]}\' is not found')
site = app.site
if not site:
desc = args['desc'] if args['desc'] else ''
copy_right = args['copyright'] if args['copyright'] else ''
privacy_policy = args['privacy_policy'] if args['privacy_policy'] else ''
custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else ''
desc = args["desc"] if args["desc"] else ""
copy_right = args["copyright"] if args["copyright"] else ""
privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
else:
desc = site.description if site.description else \
args['desc'] if args['desc'] else ''
copy_right = site.copyright if site.copyright else \
args['copyright'] if args['copyright'] else ''
privacy_policy = site.privacy_policy if site.privacy_policy else \
args['privacy_policy'] if args['privacy_policy'] else ''
custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \
args['custom_disclaimer'] if args['custom_disclaimer'] else ''
desc = site.description if site.description else args["desc"] if args["desc"] else ""
copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
privacy_policy = (
site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
)
custom_disclaimer = (
site.custom_disclaimer
if site.custom_disclaimer
else args["custom_disclaimer"]
if args["custom_disclaimer"]
else ""
)
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
if not recommended_app:
recommended_app = RecommendedApp(
@@ -83,9 +87,9 @@ class InsertExploreAppListApi(Resource):
copyright=copy_right,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
language=args['language'],
category=args['category'],
position=args['position']
language=args["language"],
category=args["category"],
position=args["position"],
)
db.session.add(recommended_app)
@@ -93,21 +97,21 @@ class InsertExploreAppListApi(Resource):
app.is_public = True
db.session.commit()
return {'result': 'success'}, 201
return {"result": "success"}, 201
else:
recommended_app.description = desc
recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args['language']
recommended_app.category = args['category']
recommended_app.position = args['position']
recommended_app.language = args["language"]
recommended_app.category = args["category"]
recommended_app.position = args["position"]
app.is_public = True
db.session.commit()
return {'result': 'success'}, 200
return {"result": "success"}, 200
class InsertExploreAppApi(Resource):
@@ -116,15 +120,14 @@ class InsertExploreAppApi(Resource):
def delete(self, app_id):
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
if not recommended_app:
return {'result': 'success'}, 204
return {"result": "success"}, 204
app = App.query.filter(App.id == recommended_app.app_id).first()
if app:
app.is_public = False
installed_apps = InstalledApp.query.filter(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
).all()
for installed_app in installed_apps:
@@ -133,8 +136,8 @@ class InsertExploreAppApi(Resource):
db.session.delete(recommended_app)
db.session.commit()
return {'result': 'success'}, 204
return {"result": "success"}, 204
api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps')
api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/<uuid:app_id>')
api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps")
api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/<uuid:app_id>")

View File

@@ -14,26 +14,21 @@ from .setup import setup_required
from .wraps import account_initialization_required
api_key_fields = {
'id': fields.String,
'type': fields.String,
'token': fields.String,
'last_used_at': TimestampField,
'created_at': TimestampField
"id": fields.String,
"type": fields.String,
"token": fields.String,
"last_used_at": TimestampField,
"created_at": TimestampField,
}
api_key_list = {
'data': fields.List(fields.Nested(api_key_fields), attribute="items")
}
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
def _get_resource(resource_id, tenant_id, resource_model):
resource = resource_model.query.filter_by(
id=resource_id, tenant_id=tenant_id
).first()
resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
if resource is None:
flask_restful.abort(
404, message=f"{resource_model.__name__} not found.")
flask_restful.abort(404, message=f"{resource_model.__name__} not found.")
return resource
@@ -50,30 +45,32 @@ class BaseApiKeyListResource(Resource):
@marshal_with(api_key_list)
def get(self, resource_id):
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id,
self.resource_model)
keys = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
all()
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.all()
)
return {"items": keys}
@marshal_with(api_key_fields)
def post(self, resource_id):
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id,
self.resource_model)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.is_admin_or_owner:
raise Forbidden()
current_key_count = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
count()
current_key_count = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.count()
)
if current_key_count >= self.max_keys:
flask_restful.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code='max_keys_exceeded'
code="max_keys_exceeded",
)
key = ApiToken.generate_api_key(self.token_prefix, 24)
@@ -97,79 +94,78 @@ class BaseApiKeyResource(Resource):
def delete(self, resource_id, api_key_id):
resource_id = str(resource_id)
api_key_id = str(api_key_id)
_get_resource(resource_id, current_user.current_tenant_id,
self.resource_model)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
key = db.session.query(ApiToken). \
filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). \
first()
key = (
db.session.query(ApiToken)
.filter(
getattr(ApiToken, self.resource_id_field) == resource_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.first()
)
if key is None:
flask_restful.abort(404, message='API key not found')
flask_restful.abort(404, message="API key not found")
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit()
return {'result': 'success'}, 204
return {"result": "success"}, 204
class AppApiKeyListResource(BaseApiKeyListResource):
def after_request(self, resp):
resp.headers['Access-Control-Allow-Origin'] = '*'
resp.headers['Access-Control-Allow-Credentials'] = 'true'
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = 'app'
resource_type = "app"
resource_model = App
resource_id_field = 'app_id'
token_prefix = 'app-'
resource_id_field = "app_id"
token_prefix = "app-"
class AppApiKeyResource(BaseApiKeyResource):
def after_request(self, resp):
resp.headers['Access-Control-Allow-Origin'] = '*'
resp.headers['Access-Control-Allow-Credentials'] = 'true'
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = 'app'
resource_type = "app"
resource_model = App
resource_id_field = 'app_id'
resource_id_field = "app_id"
class DatasetApiKeyListResource(BaseApiKeyListResource):
def after_request(self, resp):
resp.headers['Access-Control-Allow-Origin'] = '*'
resp.headers['Access-Control-Allow-Credentials'] = 'true'
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = 'dataset'
resource_type = "dataset"
resource_model = Dataset
resource_id_field = 'dataset_id'
token_prefix = 'ds-'
resource_id_field = "dataset_id"
token_prefix = "ds-"
class DatasetApiKeyResource(BaseApiKeyResource):
def after_request(self, resp):
resp.headers['Access-Control-Allow-Origin'] = '*'
resp.headers['Access-Control-Allow-Credentials'] = 'true'
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = 'dataset'
resource_type = "dataset"
resource_model = Dataset
resource_id_field = 'dataset_id'
resource_id_field = "dataset_id"
api.add_resource(AppApiKeyListResource, '/apps/<uuid:resource_id>/api-keys')
api.add_resource(AppApiKeyResource,
'/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiKeyListResource,
'/datasets/<uuid:resource_id>/api-keys')
api.add_resource(DatasetApiKeyResource,
'/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>')
api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys")
api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys")
api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")

View File

@@ -8,19 +8,18 @@ from services.advanced_prompt_template_service import AdvancedPromptTemplateServ
class AdvancedPromptTemplateList(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('app_mode', type=str, required=True, location='args')
parser.add_argument('model_mode', type=str, required=True, location='args')
parser.add_argument('has_context', type=str, required=False, default='true', location='args')
parser.add_argument('model_name', type=str, required=True, location='args')
parser.add_argument("app_mode", type=str, required=True, location="args")
parser.add_argument("model_mode", type=str, required=True, location="args")
parser.add_argument("has_context", type=str, required=False, default="true", location="args")
parser.add_argument("model_name", type=str, required=True, location="args")
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args)
api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates")

View File

@@ -18,15 +18,12 @@ class AgentLogApi(Resource):
def get(self, app_model):
"""Get agent logs"""
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=uuid_value, required=True, location='args')
parser.add_argument('conversation_id', type=uuid_value, required=True, location='args')
parser.add_argument("message_id", type=uuid_value, required=True, location="args")
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
args = parser.parse_args()
return AgentService.get_agent_logs(
app_model,
args['conversation_id'],
args['message_id']
)
api.add_resource(AgentLogApi, '/apps/<uuid:app_id>/agent/logs')
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
api.add_resource(AgentLogApi, "/apps/<uuid:app_id>/agent/logs")

View File

@@ -21,23 +21,23 @@ class AnnotationReplyActionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@cloud_edition_billing_resource_check("annotation")
def post(self, app_id, action):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument('score_threshold', required=True, type=float, location='json')
parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
parser.add_argument('embedding_model_name', required=True, type=str, location='json')
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
args = parser.parse_args()
if action == 'enable':
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == 'disable':
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
else:
raise ValueError('Unsupported annotation reply action')
raise ValueError("Unsupported annotation reply action")
return result, 200
@@ -66,7 +66,7 @@ class AppAnnotationSettingUpdateApi(Resource):
annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser()
parser.add_argument('score_threshold', required=True, type=float, location='json')
parser.add_argument("score_threshold", required=True, type=float, location="json")
args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
@@ -77,28 +77,24 @@ class AnnotationReplyActionStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@cloud_edition_billing_resource_check("annotation")
def get(self, app_id, job_id, action):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job is not exist.")
job_status = cache_result.decode()
error_msg = ''
if job_status == 'error':
app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
error_msg = ""
if job_status == "error":
app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode()
return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
class AnnotationListApi(Resource):
@@ -109,18 +105,18 @@ class AnnotationListApi(Resource):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
keyword = request.args.get('keyword', default=None, type=str)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default=None, type=str)
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
response = {
'data': marshal(annotation_list, annotation_fields),
'has_more': len(annotation_list) == limit,
'limit': limit,
'total': total,
'page': page
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
@@ -135,9 +131,7 @@ class AnnotationExportApi(Resource):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {
'data': marshal(annotation_list, annotation_fields)
}
response = {"data": marshal(annotation_list, annotation_fields)}
return response, 200
@@ -145,7 +139,7 @@ class AnnotationCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
def post(self, app_id):
if not current_user.is_editor:
@@ -153,8 +147,8 @@ class AnnotationCreateApi(Resource):
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation
@@ -164,7 +158,7 @@ class AnnotationUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
if not current_user.is_editor:
@@ -173,8 +167,8 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id)
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation
@@ -189,29 +183,29 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
return {'result': 'success'}, 200
return {"result": "success"}, 200
class AnnotationBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@cloud_edition_billing_resource_check("annotation")
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
# get file from request
file = request.files['file']
file = request.files["file"]
# check file
if 'file' not in request.files:
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename.endswith('.csv'):
if not file.filename.endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
return AppAnnotationService.batch_import_app_annotations(app_id, file)
@@ -220,27 +214,23 @@ class AnnotationBatchImportStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@cloud_edition_billing_resource_check("annotation")
def get(self, app_id, job_id):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
raise ValueError("The job is not exist.")
job_status = cache_result.decode()
error_msg = ''
if job_status == 'error':
indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
error_msg = ""
if job_status == "error":
indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id))
error_msg = redis_client.get(indexing_error_msg_key).decode()
return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
class AnnotationHitHistoryListApi(Resource):
@@ -251,30 +241,32 @@ class AnnotationHitHistoryListApi(Resource):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
app_id = str(app_id)
annotation_id = str(annotation_id)
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
page, limit)
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
app_id, annotation_id, page, limit
)
response = {
'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
'has_more': len(annotation_hit_history_list) == limit,
'limit': limit,
'total': total,
'page': page
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
"has_more": len(annotation_hit_history_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response
api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
api.add_resource(AnnotationReplyActionStatusApi,
'/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')
api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>")
api.add_resource(
AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>"
)
api.add_resource(AnnotationListApi, "/apps/<uuid:app_id>/annotations")
api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export")
api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import")
api.add_resource(AnnotationBatchImportStatusApi, "/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
api.add_resource(AnnotationHitHistoryListApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
api.add_resource(AppAnnotationSettingDetailApi, "/apps/<uuid:app_id>/annotation-setting")
api.add_resource(AppAnnotationSettingUpdateApi, "/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")

View File

@@ -18,27 +18,35 @@ from libs.login import login_required
from services.app_dsl_service import AppDslService
from services.app_service import AppService
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
class AppListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""Get app list"""
def uuid_list(value):
try:
return [str(uuid.UUID(v)) for v in value.split(',')]
return [str(uuid.UUID(v)) for v in value.split(",")]
except ValueError:
abort(400, message="Invalid UUID format in tag_ids.")
parser = reqparse.RequestParser()
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
parser.add_argument('name', type=str, location='args', required=False)
parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"mode",
type=str,
choices=["chat", "workflow", "agent-chat", "channel", "all"],
default="all",
location="args",
required=False,
)
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
args = parser.parse_args()
@@ -46,7 +54,7 @@ class AppListApi(Resource):
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
if not app_pagination:
return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False}
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
return marshal(app_pagination, app_pagination_fields)
@@ -54,23 +62,23 @@ class AppListApi(Resource):
@login_required
@account_initialization_required
@marshal_with(app_detail_fields)
@cloud_edition_billing_resource_check('apps')
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Create app"""
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if 'mode' not in args or args['mode'] is None:
if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required")
app_service = AppService()
@@ -84,7 +92,7 @@ class AppImportApi(Resource):
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check('apps')
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Import app"""
# The role of the current user in the ta table must be admin, owner, or editor
@@ -92,19 +100,16 @@ class AppImportApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id,
data=args['data'],
args=args,
account=current_user
tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
)
return app, 201
@@ -115,7 +120,7 @@ class AppImportFromUrlApi(Resource):
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check('apps')
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Import app from url"""
# The role of the current user in the ta table must be admin, owner, or editor
@@ -123,25 +128,21 @@ class AppImportFromUrlApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('url', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument("url", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
app = AppDslService.import_and_create_new_app_from_url(
tenant_id=current_user.current_tenant_id,
url=args['url'],
args=args,
account=current_user
tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
)
return app, 201
class AppApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -165,14 +166,15 @@ class AppApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument('max_active_requests', type=int, location='json')
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("max_active_requests", type=int, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
args = parser.parse_args()
app_service = AppService()
@@ -193,7 +195,7 @@ class AppApi(Resource):
app_service = AppService()
app_service.delete_app(app_model)
return {'result': 'success'}, 204
return {"result": "success"}, 204
class AppCopyApi(Resource):
@@ -209,19 +211,16 @@ class AppCopyApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id,
data=data,
args=args,
account=current_user
tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user
)
return app, 201
@@ -240,12 +239,10 @@ class AppExportApi(Resource):
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args')
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
args = parser.parse_args()
return {
"data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret'])
}
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
class AppNameApi(Resource):
@@ -258,13 +255,13 @@ class AppNameApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_name(app_model, args.get('name'))
app_model = app_service.update_app_name(app_model, args.get("name"))
return app_model
@@ -279,14 +276,14 @@ class AppIconApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background'))
app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background"))
return app_model
@@ -301,13 +298,13 @@ class AppSiteStatus(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('enable_site', type=bool, required=True, location='json')
parser.add_argument("enable_site", type=bool, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.get('enable_site'))
app_model = app_service.update_app_site_status(app_model, args.get("enable_site"))
return app_model
@@ -322,13 +319,13 @@ class AppApiStatus(Resource):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('enable_api', type=bool, required=True, location='json')
parser.add_argument("enable_api", type=bool, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.get('enable_api'))
app_model = app_service.update_app_api_status(app_model, args.get("enable_api"))
return app_model
@@ -339,9 +336,7 @@ class AppTraceApi(Resource):
@account_initialization_required
def get(self, app_id):
"""Get app trace"""
app_trace_config = OpsTraceManager.get_app_tracing_config(
app_id=app_id
)
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id)
return app_trace_config
@@ -353,27 +348,27 @@ class AppTraceApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('enabled', type=bool, required=True, location='json')
parser.add_argument('tracing_provider', type=str, required=True, location='json')
parser.add_argument("enabled", type=bool, required=True, location="json")
parser.add_argument("tracing_provider", type=str, required=True, location="json")
args = parser.parse_args()
OpsTraceManager.update_app_tracing_config(
app_id=app_id,
enabled=args['enabled'],
tracing_provider=args['tracing_provider'],
enabled=args["enabled"],
tracing_provider=args["tracing_provider"],
)
return {"result": "success"}
api.add_resource(AppListApi, '/apps')
api.add_resource(AppImportApi, '/apps/import')
api.add_resource(AppImportFromUrlApi, '/apps/import/url')
api.add_resource(AppApi, '/apps/<uuid:app_id>')
api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon')
api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
api.add_resource(AppTraceApi, '/apps/<uuid:app_id>/trace')
api.add_resource(AppListApi, "/apps")
api.add_resource(AppImportApi, "/apps/import")
api.add_resource(AppImportFromUrlApi, "/apps/import/url")
api.add_resource(AppApi, "/apps/<uuid:app_id>")
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
api.add_resource(AppNameApi, "/apps/<uuid:app_id>/name")
api.add_resource(AppIconApi, "/apps/<uuid:app_id>/icon")
api.add_resource(AppSiteStatus, "/apps/<uuid:app_id>/site-enable")
api.add_resource(AppApiStatus, "/apps/<uuid:app_id>/api-enable")
api.add_resource(AppTraceApi, "/apps/<uuid:app_id>/trace")

View File

@@ -39,7 +39,7 @@ class ChatMessageAudioApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model):
file = request.files['file']
file = request.files["file"]
try:
response = AudioService.transcript_asr(
@@ -85,31 +85,31 @@ class ChatMessageTextApi(Resource):
try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, location='json')
parser.add_argument('text', type=str, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
parser.add_argument("message_id", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = args.get('message_id', None)
text = args.get('text', None)
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
else:
try:
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
'voice')
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
text=text,
message_id=message_id,
voice=voice
)
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
@@ -145,12 +145,12 @@ class TextModesApi(Resource):
def get(self, app_model):
try:
parser = reqparse.RequestParser()
parser.add_argument('language', type=str, required=True, location='args')
parser.add_argument("language", type=str, required=True, location="args")
args = parser.parse_args()
response = AudioService.transcript_tts_voices(
tenant_id=app_model.tenant_id,
language=args['language'],
language=args["language"],
)
return response
@@ -179,6 +179,6 @@ class TextModesApi(Resource):
raise InternalServerError()
api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
api.add_resource(ChatMessageAudioApi, "/apps/<uuid:app_id>/audio-to-text")
api.add_resource(ChatMessageTextApi, "/apps/<uuid:app_id>/text-to-audio")
api.add_resource(TextModesApi, "/apps/<uuid:app_id>/text-to-audio/voices")

View File

@@ -17,6 +17,7 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@@ -31,37 +32,33 @@ from libs.helper import uuid_value
from libs.login import login_required
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
# define completion message api for user
class CompletionMessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args()
streaming = args['response_mode'] != 'blocking'
args['auto_generate_name'] = False
streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False
account = flask_login.current_user
try:
response = AppGenerateService.generate(
app_model=app_model,
user=account,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
)
return helper.compact_generate_response(response)
@@ -97,7 +94,7 @@ class CompletionMessageStopApi(Resource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {'result': 'success'}, 200
return {"result": "success"}, 200
class ChatMessageApi(Resource):
@@ -107,27 +104,23 @@ class ChatMessageApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
def post(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args()
streaming = args['response_mode'] != 'blocking'
args['auto_generate_name'] = False
streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False
account = flask_login.current_user
try:
response = AppGenerateService.generate(
app_model=app_model,
user=account,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
)
return helper.compact_generate_response(response)
@@ -144,6 +137,8 @@ class ChatMessageApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
@@ -163,10 +158,10 @@ class ChatMessageStopApi(Resource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {'result': 'success'}, 200
return {"result": "success"}, 200
api.add_resource(CompletionMessageApi, '/apps/<uuid:app_id>/completion-messages')
api.add_resource(CompletionMessageStopApi, '/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop')
api.add_resource(ChatMessageApi, '/apps/<uuid:app_id>/chat-messages')
api.add_resource(ChatMessageStopApi, '/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop')
api.add_resource(CompletionMessageApi, "/apps/<uuid:app_id>/completion-messages")
api.add_resource(CompletionMessageStopApi, "/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
api.add_resource(ChatMessageApi, "/apps/<uuid:app_id>/chat-messages")
api.add_resource(ChatMessageStopApi, "/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")

View File

@@ -26,7 +26,6 @@ from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotat
class CompletionConversationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -36,24 +35,23 @@ class CompletionConversationApi(Resource):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('keyword', type=str, location='args')
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('annotation_status', type=str,
choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion')
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")
if args['keyword']:
query = query.join(
Message, Message.conversation_id == Conversation.id
).filter(
if args["keyword"]:
query = query.join(Message, Message.conversation_id == Conversation.id).filter(
or_(
Message.query.ilike('%{}%'.format(args['keyword'])),
Message.answer.ilike('%{}%'.format(args['keyword']))
Message.query.ilike("%{}%".format(args["keyword"])),
Message.answer.ilike("%{}%".format(args["keyword"])),
)
)
@@ -61,8 +59,8 @@ class CompletionConversationApi(Resource):
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
@@ -70,8 +68,8 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at >= start_datetime_utc)
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
@@ -79,29 +77,25 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at < end_datetime_utc)
if args['annotation_status'] == "annotated":
if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join(
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args['annotation_status'] == "not_annotated":
query = query.outerjoin(
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
elif args["annotation_status"] == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(
query,
page=args['page'],
per_page=args['limit'],
error_out=False
)
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
return conversations
class CompletionConversationDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -123,8 +117,11 @@ class CompletionConversationDetailApi(Resource):
raise Forbidden()
conversation_id = str(conversation_id)
conversation = db.session.query(Conversation) \
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
@@ -132,11 +129,10 @@ class CompletionConversationDetailApi(Resource):
conversation.is_deleted = True
db.session.commit()
return {'result': 'success'}, 204
return {"result": "success"}, 204
class ChatConversationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -146,22 +142,28 @@ class ChatConversationApi(Resource):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('keyword', type=str, location='args')
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('annotation_status', type=str,
choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args')
parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
args = parser.parse_args()
subquery = (
db.session.query(
Conversation.id.label('conversation_id'),
EndUser.session_id.label('from_end_user_session_id')
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
)
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
.subquery()
@@ -169,28 +171,31 @@ class ChatConversationApi(Resource):
query = db.select(Conversation).where(Conversation.app_id == app_model.id)
if args['keyword']:
keyword_filter = '%{}%'.format(args['keyword'])
query = query.join(
Message, Message.conversation_id == Conversation.id,
).join(
subquery, subquery.c.conversation_id == Conversation.id
).filter(
or_(
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
Conversation.name.ilike(keyword_filter),
Conversation.introduction.ilike(keyword_filter),
subquery.c.from_end_user_session_id.ilike(keyword_filter)
),
if args["keyword"]:
keyword_filter = "%{}%".format(args["keyword"])
query = (
query.join(
Message,
Message.conversation_id == Conversation.id,
)
.join(subquery, subquery.c.conversation_id == Conversation.id)
.filter(
or_(
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
Conversation.name.ilike(keyword_filter),
Conversation.introduction.ilike(keyword_filter),
subquery.c.from_end_user_session_id.ilike(keyword_filter),
),
)
)
account = current_user
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
@@ -198,8 +203,8 @@ class ChatConversationApi(Resource):
query = query.where(Conversation.created_at >= start_datetime_utc)
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
@@ -207,50 +212,46 @@ class ChatConversationApi(Resource):
query = query.where(Conversation.created_at < end_datetime_utc)
if args['annotation_status'] == "annotated":
if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join(
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args['annotation_status'] == "not_annotated":
query = query.outerjoin(
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
elif args["annotation_status"] == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
if args['message_count_gte'] and args['message_count_gte'] >= 1:
if args["message_count_gte"] and args["message_count_gte"] >= 1:
query = (
query.options(joinedload(Conversation.messages))
.join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(Message.id) >= args['message_count_gte'])
.having(func.count(Message.id) >= args["message_count_gte"])
)
if app_model.mode == AppMode.ADVANCED_CHAT.value:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
match args['sort_by']:
case 'created_at':
match args["sort_by"]:
case "created_at":
query = query.order_by(Conversation.created_at.asc())
case '-created_at':
case "-created_at":
query = query.order_by(Conversation.created_at.desc())
case 'updated_at':
case "updated_at":
query = query.order_by(Conversation.updated_at.asc())
case '-updated_at':
case "-updated_at":
query = query.order_by(Conversation.updated_at.desc())
case _:
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(
query,
page=args['page'],
per_page=args['limit'],
error_out=False
)
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
return conversations
class ChatConversationDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -272,8 +273,11 @@ class ChatConversationDetailApi(Resource):
raise Forbidden()
conversation_id = str(conversation_id)
conversation = db.session.query(Conversation) \
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
@@ -281,18 +285,21 @@ class ChatConversationDetailApi(Resource):
conversation.is_deleted = True
db.session.commit()
return {'result': 'success'}, 204
return {"result": "success"}, 204
api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>')
api.add_resource(CompletionConversationApi, "/apps/<uuid:app_id>/completion-conversations")
api.add_resource(CompletionConversationDetailApi, "/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
api.add_resource(ChatConversationApi, "/apps/<uuid:app_id>/chat-conversations")
api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
def _get_conversation(app_model, conversation_id):
conversation = db.session.query(Conversation) \
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")

View File

@@ -21,7 +21,7 @@ class ConversationVariablesApi(Resource):
@marshal_with(paginated_conversation_variable_fields)
def get(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument('conversation_id', type=str, location='args')
parser.add_argument("conversation_id", type=str, location="args")
args = parser.parse_args()
stmt = (
@@ -29,10 +29,10 @@ class ConversationVariablesApi(Resource):
.where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at)
)
if args['conversation_id']:
stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id'])
if args["conversation_id"]:
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
else:
raise ValueError('conversation_id is required')
raise ValueError("conversation_id is required")
# NOTE: This is a temporary solution to avoid performance issues.
page = 1
@@ -43,14 +43,14 @@ class ConversationVariablesApi(Resource):
rows = session.scalars(stmt).all()
return {
'page': page,
'limit': page_size,
'total': len(rows),
'has_more': False,
'data': [
"page": page,
"limit": page_size,
"total": len(rows),
"has_more": False,
"data": [
{
'created_at': row.created_at,
'updated_at': row.updated_at,
"created_at": row.created_at,
"updated_at": row.updated_at,
**row.to_variable().model_dump(),
}
for row in rows
@@ -58,4 +58,4 @@ class ConversationVariablesApi(Resource):
}
api.add_resource(ConversationVariablesApi, '/apps/<uuid:app_id>/conversation-variables')
api.add_resource(ConversationVariablesApi, "/apps/<uuid:app_id>/conversation-variables")

View File

@@ -2,116 +2,128 @@ from libs.exception import BaseHTTPException
class AppNotFoundError(BaseHTTPException):
error_code = 'app_not_found'
error_code = "app_not_found"
description = "App not found."
code = 404
class ProviderNotInitializeError(BaseHTTPException):
error_code = 'provider_not_initialize'
description = "No valid model provider credentials found. " \
"Please go to Settings -> Model Provider to complete your provider credentials."
error_code = "provider_not_initialize"
description = (
"No valid model provider credentials found. "
"Please go to Settings -> Model Provider to complete your provider credentials."
)
code = 400
class ProviderQuotaExceededError(BaseHTTPException):
error_code = 'provider_quota_exceeded'
description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
"Please go to Settings -> Model Provider to complete your own provider credentials."
error_code = "provider_quota_exceeded"
description = (
"Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials."
)
code = 400
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
error_code = 'model_currently_not_support'
error_code = "model_currently_not_support"
description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
code = 400
class ConversationCompletedError(BaseHTTPException):
error_code = 'conversation_completed'
error_code = "conversation_completed"
description = "The conversation has ended. Please start a new conversation."
code = 400
class AppUnavailableError(BaseHTTPException):
error_code = 'app_unavailable'
error_code = "app_unavailable"
description = "App unavailable, please check your app configurations."
code = 400
class CompletionRequestError(BaseHTTPException):
error_code = 'completion_request_error'
error_code = "completion_request_error"
description = "Completion request failed."
code = 400
class AppMoreLikeThisDisabledError(BaseHTTPException):
error_code = 'app_more_like_this_disabled'
error_code = "app_more_like_this_disabled"
description = "The 'More like this' feature is disabled. Please refresh your page."
code = 403
class NoAudioUploadedError(BaseHTTPException):
error_code = 'no_audio_uploaded'
error_code = "no_audio_uploaded"
description = "Please upload your audio."
code = 400
class AudioTooLargeError(BaseHTTPException):
error_code = 'audio_too_large'
error_code = "audio_too_large"
description = "Audio size exceeded. {message}"
code = 413
class UnsupportedAudioTypeError(BaseHTTPException):
error_code = 'unsupported_audio_type'
error_code = "unsupported_audio_type"
description = "Audio type not allowed."
code = 415
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = 'provider_not_support_speech_to_text'
error_code = "provider_not_support_speech_to_text"
description = "Provider not support speech to text."
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = 'no_file_uploaded'
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files'
error_code = "too_many_files"
description = "Only one file is allowed."
code = 400
class DraftWorkflowNotExist(BaseHTTPException):
error_code = 'draft_workflow_not_exist'
error_code = "draft_workflow_not_exist"
description = "Draft workflow need to be initialized."
code = 400
class DraftWorkflowNotSync(BaseHTTPException):
error_code = 'draft_workflow_not_sync'
error_code = "draft_workflow_not_sync"
description = "Workflow graph might have been modified, please refresh and resubmit."
code = 400
class TracingConfigNotExist(BaseHTTPException):
error_code = 'trace_config_not_exist'
error_code = "trace_config_not_exist"
description = "Trace config not exist."
code = 400
class TracingConfigIsExist(BaseHTTPException):
error_code = 'trace_config_is_exist'
error_code = "trace_config_is_exist"
description = "Trace config is exist."
code = 400
class TracingConfigCheckError(BaseHTTPException):
error_code = 'trace_config_check_error'
error_code = "trace_config_check_error"
description = "Invalid Credentials."
code = 400
class InvokeRateLimitError(BaseHTTPException):
"""Raised when the Invoke returns rate limit error."""
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429

View File

@@ -24,21 +24,21 @@ class RuleGenerateApi(Resource):
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('instruction', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json')
parser.add_argument('no_variable', type=bool, required=True, default=False, location='json')
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
args = parser.parse_args()
account = current_user
PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512'))
PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=account.current_tenant_id,
instruction=args['instruction'],
model_config=args['model_config'],
no_variable=args['no_variable'],
rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=args["no_variable"],
rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -52,4 +52,4 @@ class RuleGenerateApi(Resource):
return rules
api.add_resource(RuleGenerateApi, '/rule-generate')
api.add_resource(RuleGenerateApi, "/rule-generate")

View File

@@ -33,9 +33,9 @@ from services.message_service import MessageService
class ChatMessageListApi(Resource):
message_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_detail_fields))
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_detail_fields)),
}
@setup_required
@@ -45,55 +45,69 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("first_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
conversation = db.session.query(Conversation).filter(
Conversation.id == args['conversation_id'],
Conversation.app_id == app_model.id
).first()
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
if args['first_id']:
first_message = db.session.query(Message) \
.filter(Message.conversation_id == conversation.id, Message.id == args['first_id']).first()
if args["first_id"]:
first_message = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first()
)
if not first_message:
raise NotFound("First message not found")
history_messages = db.session.query(Message).filter(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
Message.id != first_message.id
) \
.order_by(Message.created_at.desc()).limit(args['limit']).all()
history_messages = (
db.session.query(Message)
.filter(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
Message.id != first_message.id,
)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.all()
)
else:
history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
.order_by(Message.created_at.desc()).limit(args['limit']).all()
history_messages = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.all()
)
has_more = False
if len(history_messages) == args['limit']:
if len(history_messages) == args["limit"]:
current_page_first_message = history_messages[-1]
rest_count = db.session.query(Message).filter(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id
).count()
rest_count = (
db.session.query(Message)
.filter(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
)
.count()
)
if rest_count > 0:
has_more = True
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(
data=history_messages,
limit=args['limit'],
has_more=has_more
)
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
class MessageFeedbackApi(Resource):
@@ -103,49 +117,46 @@ class MessageFeedbackApi(Resource):
@get_app_model
def post(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument('message_id', required=True, type=uuid_value, location='json')
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
args = parser.parse_args()
message_id = str(args['message_id'])
message_id = str(args["message_id"])
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id
).first()
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")
feedback = message.admin_feedback
if not args['rating'] and feedback:
if not args["rating"] and feedback:
db.session.delete(feedback)
elif args['rating'] and feedback:
feedback.rating = args['rating']
elif not args['rating'] and not feedback:
raise ValueError('rating cannot be None when feedback not exists')
elif args["rating"] and feedback:
feedback.rating = args["rating"]
elif not args["rating"] and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=args['rating'],
from_source='admin',
from_account_id=current_user.id
rating=args["rating"],
from_source="admin",
from_account_id=current_user.id,
)
db.session.add(feedback)
db.session.commit()
return {'result': 'success'}
return {"result": "success"}
class MessageAnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@cloud_edition_billing_resource_check("annotation")
@get_app_model
@marshal_with(annotation_fields)
def post(self, app_model):
@@ -153,10 +164,10 @@ class MessageAnnotationApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('message_id', required=False, type=uuid_value, location='json')
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
parser.add_argument('annotation_reply', required=False, type=dict, location='json')
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
@@ -169,11 +180,9 @@ class MessageAnnotationCountApi(Resource):
@account_initialization_required
@get_app_model
def get(self, app_model):
count = db.session.query(MessageAnnotation).filter(
MessageAnnotation.app_id == app_model.id
).count()
count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count()
return {'count': count}
return {"count": count}
class MessageSuggestedQuestionApi(Resource):
@@ -186,10 +195,7 @@ class MessageSuggestedQuestionApi(Resource):
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
message_id=message_id,
user=current_user,
invoke_from=InvokeFrom.DEBUGGER
app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER
)
except MessageNotExistsError:
raise NotFound("Message not found")
@@ -209,7 +215,7 @@ class MessageSuggestedQuestionApi(Resource):
logging.exception("internal server error.")
raise InternalServerError()
return {'data': questions}
return {"data": questions}
class MessageApi(Resource):
@@ -221,10 +227,7 @@ class MessageApi(Resource):
def get(self, app_model, message_id):
message_id = str(message_id)
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id
).first()
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")
@@ -232,9 +235,9 @@ class MessageApi(Resource):
return message
api.add_resource(MessageSuggestedQuestionApi, '/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions')
api.add_resource(ChatMessageListApi, '/apps/<uuid:app_id>/chat-messages', endpoint='console_chat_messages')
api.add_resource(MessageFeedbackApi, '/apps/<uuid:app_id>/feedbacks')
api.add_resource(MessageAnnotationApi, '/apps/<uuid:app_id>/annotations')
api.add_resource(MessageAnnotationCountApi, '/apps/<uuid:app_id>/annotations/count')
api.add_resource(MessageApi, '/apps/<uuid:app_id>/messages/<uuid:message_id>', endpoint='console_message')
api.add_resource(MessageSuggestedQuestionApi, "/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
api.add_resource(ChatMessageListApi, "/apps/<uuid:app_id>/chat-messages", endpoint="console_chat_messages")
api.add_resource(MessageFeedbackApi, "/apps/<uuid:app_id>/feedbacks")
api.add_resource(MessageAnnotationApi, "/apps/<uuid:app_id>/annotations")
api.add_resource(MessageAnnotationCountApi, "/apps/<uuid:app_id>/annotations/count")
api.add_resource(MessageApi, "/apps/<uuid:app_id>/messages/<uuid:message_id>", endpoint="console_message")

View File

@@ -19,37 +19,35 @@ from services.app_model_config_service import AppModelConfigService
class ModelConfigResource(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model):
"""Modify app model config"""
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
config=request.json,
app_mode=AppMode.value_of(app_model.mode)
tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode)
)
new_app_model_config = AppModelConfig(
app_id=app_model.id,
created_by=current_user.id,
updated_by=current_user.id,
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
# get original app model config
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
AppModelConfig.id == app_model.app_model_config_id
).first()
original_app_model_config: AppModelConfig = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
)
agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
parameter_map = {}
masked_parameter_map = {}
tool_map = {}
for tool in agent_mode.get('tools') or []:
for tool in agent_mode.get("tools") or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
@@ -66,7 +64,7 @@ class ModelConfigResource(Resource):
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f'AGENT.{app_model.id}'
identity_id=f"AGENT.{app_model.id}",
)
except Exception as e:
continue
@@ -79,18 +77,18 @@ class ModelConfigResource(Resource):
parameters = {}
masked_parameter = {}
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters
tool_map[key] = tool_runtime
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get('tools') or []:
for tool in agent_mode.get("tools") or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
if key in tool_map:
tool_runtime = tool_map[key]
else:
@@ -108,7 +106,7 @@ class ModelConfigResource(Resource):
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f'AGENT.{app_model.id}'
identity_id=f"AGENT.{app_model.id}",
)
manager.delete_tool_parameters_cache()
@@ -116,15 +114,17 @@ class ModelConfigResource(Resource):
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue
for masked_key, masked_value in masked_parameter_map[key].items():
if masked_key in agent_tool_entity.tool_parameters and \
agent_tool_entity.tool_parameters[masked_key] == masked_value:
if (
masked_key in agent_tool_entity.tool_parameters
and agent_tool_entity.tool_parameters[masked_key] == masked_value
):
agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
# encrypt parameters
if agent_tool_entity.tool_parameters:
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
# update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode)
@@ -135,12 +135,9 @@ class ModelConfigResource(Resource):
app_model.app_model_config_id = new_app_model_config.id
db.session.commit()
app_model_config_was_updated.send(
app_model,
app_model_config=new_app_model_config
)
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
return {'result': 'success'}
return {"result": "success"}
api.add_resource(ModelConfigResource, '/apps/<uuid:app_id>/model-config')
api.add_resource(ModelConfigResource, "/apps/<uuid:app_id>/model-config")

View File

@@ -18,13 +18,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def get(self, app_id):
parser = reqparse.RequestParser()
parser.add_argument('tracing_provider', type=str, required=True, location='args')
parser.add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
try:
trace_config = OpsService.get_tracing_app_config(
app_id=app_id, tracing_provider=args['tracing_provider']
)
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
if not trace_config:
return {"has_not_configured": True}
return trace_config
@@ -37,19 +35,17 @@ class TraceAppConfigApi(Resource):
def post(self, app_id):
"""Create a new trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument('tracing_provider', type=str, required=True, location='json')
parser.add_argument('tracing_config', type=dict, required=True, location='json')
parser.add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument("tracing_config", type=dict, required=True, location="json")
args = parser.parse_args()
try:
result = OpsService.create_tracing_app_config(
app_id=app_id,
tracing_provider=args['tracing_provider'],
tracing_config=args['tracing_config']
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
)
if not result:
raise TracingConfigIsExist()
if result.get('error'):
if result.get("error"):
raise TracingConfigCheckError()
return result
except Exception as e:
@@ -61,15 +57,13 @@ class TraceAppConfigApi(Resource):
def patch(self, app_id):
"""Update an existing trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument('tracing_provider', type=str, required=True, location='json')
parser.add_argument('tracing_config', type=dict, required=True, location='json')
parser.add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument("tracing_config", type=dict, required=True, location="json")
args = parser.parse_args()
try:
result = OpsService.update_tracing_app_config(
app_id=app_id,
tracing_provider=args['tracing_provider'],
tracing_config=args['tracing_config']
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
)
if not result:
raise TracingConfigNotExist()
@@ -83,14 +77,11 @@ class TraceAppConfigApi(Resource):
def delete(self, app_id):
"""Delete an existing trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument('tracing_provider', type=str, required=True, location='args')
parser.add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
try:
result = OpsService.delete_tracing_app_config(
app_id=app_id,
tracing_provider=args['tracing_provider']
)
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
if not result:
raise TracingConfigNotExist()
return {"result": "success"}
@@ -98,4 +89,4 @@ class TraceAppConfigApi(Resource):
raise e
api.add_resource(TraceAppConfigApi, '/apps/<uuid:app_id>/trace-config')
api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")

View File

@@ -1,3 +1,5 @@
from datetime import datetime, timezone
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
@@ -15,23 +17,24 @@ from models.model import Site
def parse_app_site_args():
parser = reqparse.RequestParser()
parser.add_argument('title', type=str, required=False, location='json')
parser.add_argument('icon_type', type=str, required=False, location='json')
parser.add_argument('icon', type=str, required=False, location='json')
parser.add_argument('icon_background', type=str, required=False, location='json')
parser.add_argument('description', type=str, required=False, location='json')
parser.add_argument('default_language', type=supported_language, required=False, location='json')
parser.add_argument('chat_color_theme', type=str, required=False, location='json')
parser.add_argument('chat_color_theme_inverted', type=bool, required=False, location='json')
parser.add_argument('customize_domain', type=str, required=False, location='json')
parser.add_argument('copyright', type=str, required=False, location='json')
parser.add_argument('privacy_policy', type=str, required=False, location='json')
parser.add_argument('custom_disclaimer', type=str, required=False, location='json')
parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'],
required=False,
location='json')
parser.add_argument('prompt_public', type=bool, required=False, location='json')
parser.add_argument('show_workflow_steps', type=bool, required=False, location='json')
parser.add_argument("title", type=str, required=False, location="json")
parser.add_argument("icon_type", type=str, required=False, location="json")
parser.add_argument("icon", type=str, required=False, location="json")
parser.add_argument("icon_background", type=str, required=False, location="json")
parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("default_language", type=supported_language, required=False, location="json")
parser.add_argument("chat_color_theme", type=str, required=False, location="json")
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
parser.add_argument("customize_domain", type=str, required=False, location="json")
parser.add_argument("copyright", type=str, required=False, location="json")
parser.add_argument("privacy_policy", type=str, required=False, location="json")
parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
parser.add_argument(
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
)
parser.add_argument("prompt_public", type=bool, required=False, location="json")
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
return parser.parse_args()
@@ -48,38 +51,38 @@ class AppSite(Resource):
if not current_user.is_editor:
raise Forbidden()
site = db.session.query(Site). \
filter(Site.app_id == app_model.id). \
one_or_404()
site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404()
for attr_name in [
'title',
'icon_type',
'icon',
'icon_background',
'description',
'default_language',
'chat_color_theme',
'chat_color_theme_inverted',
'customize_domain',
'copyright',
'privacy_policy',
'custom_disclaimer',
'customize_token_strategy',
'prompt_public',
'show_workflow_steps'
"title",
"icon_type",
"icon",
"icon_background",
"description",
"default_language",
"chat_color_theme",
"chat_color_theme_inverted",
"customize_domain",
"copyright",
"privacy_policy",
"custom_disclaimer",
"customize_token_strategy",
"prompt_public",
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = args.get(attr_name)
if value is not None:
setattr(site, attr_name, value)
site.updated_by = current_user.id
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
return site
class AppSiteAccessTokenReset(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -96,10 +99,12 @@ class AppSiteAccessTokenReset(Resource):
raise NotFound
site.code = Site.generate_code(16)
site.updated_by = current_user.id
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
return site
api.add_resource(AppSite, '/apps/<uuid:app_id>/site')
api.add_resource(AppSiteAccessTokenReset, '/apps/<uuid:app_id>/site/access-token-reset')
api.add_resource(AppSite, "/apps/<uuid:app_id>/site")
api.add_resource(AppSiteAccessTokenReset, "/apps/<uuid:app_id>/site/access-token-reset")

View File

@@ -16,8 +16,7 @@ from libs.login import login_required
from models.model import AppMode
class DailyConversationStatistic(Resource):
class DailyMessageStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -26,58 +25,52 @@ class DailyConversationStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = '''
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count
FROM messages where app_id = :app_id
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' GROUP BY date order by date'
sql_query += " GROUP BY date order by date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'conversation_count': i.conversation_count
})
response_data.append({"date": str(i.date), "message_count": i.message_count})
return jsonify({
'data': response_data
})
return jsonify({"data": response_data})
class DailyTerminalsStatistic(Resource):
class DailyConversationStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -86,54 +79,103 @@ class DailyTerminalsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = '''
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
FROM messages where app_id = :app_id
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' GROUP BY date order by date'
sql_query += " GROUP BY date order by date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'terminal_count': i.terminal_count
})
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
return jsonify({
'data': response_data
})
return jsonify({"data": response_data})
class DailyTerminalsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
return jsonify({"data": response_data})
class DailyTokenCostStatistic(Resource):
@@ -145,58 +187,53 @@ class DailyTokenCostStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = '''
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
(sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
sum(total_price) as total_price
FROM messages where app_id = :app_id
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' GROUP BY date order by date'
sql_query += " GROUP BY date order by date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'token_count': i.token_count,
'total_price': i.total_price,
'currency': 'USD'
})
response_data.append(
{"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"}
)
return jsonify({
'data': response_data
})
return jsonify({"data": response_data})
class AverageSessionInteractionStatistic(Resource):
@@ -208,8 +245,8 @@ class AverageSessionInteractionStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
@@ -218,30 +255,30 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
FROM conversations c
JOIN messages m ON c.id = m.conversation_id
WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and c.created_at >= :start'
arg_dict['start'] = start_datetime_utc
sql_query += " and c.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and c.created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " and c.created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += """
GROUP BY m.conversation_id) subquery
@@ -250,18 +287,15 @@ GROUP BY date
ORDER BY date"""
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'interactions': float(i.interactions.quantize(Decimal('0.01')))
})
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
)
return jsonify({
'data': response_data
})
return jsonify({"data": response_data})
class UserSatisfactionRateStatistic(Resource):
@@ -273,57 +307,57 @@ class UserSatisfactionRateStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = '''
sql_query = """
SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
FROM messages m
LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
WHERE m.app_id = :app_id
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and m.created_at >= :start'
arg_dict['start'] = start_datetime_utc
sql_query += " and m.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and m.created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " and m.created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' GROUP BY date order by date'
sql_query += " GROUP BY date order by date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
})
response_data.append(
{
"date": str(i.date),
"rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
}
)
return jsonify({
'data': response_data
})
return jsonify({"data": response_data})
class AverageResponseTimeStatistic(Resource):
@@ -335,56 +369,51 @@ class AverageResponseTimeStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = '''
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(provider_response_latency) as latency
FROM messages
WHERE app_id = :app_id
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' GROUP BY date order by date'
sql_query += " GROUP BY date order by date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'latency': round(i.latency * 1000, 4)
})
response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)})
return jsonify({
'data': response_data
})
return jsonify({"data": response_data})
class TokensPerSecondStatistic(Resource):
@@ -396,63 +425,59 @@ class TokensPerSecondStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
CASE
WHEN SUM(provider_response_latency) = 0 THEN 0
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
END as tokens_per_second
FROM messages
WHERE app_id = :app_id'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
WHERE app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' GROUP BY date order by date'
sql_query += " GROUP BY date order by date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'tps': round(i.tokens_per_second, 4)
})
response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})
return jsonify({
'data': response_data
})
return jsonify({"data": response_data})
api.add_resource(DailyConversationStatistic, '/apps/<uuid:app_id>/statistics/daily-conversations')
api.add_resource(DailyTerminalsStatistic, '/apps/<uuid:app_id>/statistics/daily-end-users')
api.add_resource(DailyTokenCostStatistic, '/apps/<uuid:app_id>/statistics/token-costs')
api.add_resource(AverageSessionInteractionStatistic, '/apps/<uuid:app_id>/statistics/average-session-interactions')
api.add_resource(UserSatisfactionRateStatistic, '/apps/<uuid:app_id>/statistics/user-satisfaction-rate')
api.add_resource(AverageResponseTimeStatistic, '/apps/<uuid:app_id>/statistics/average-response-time')
api.add_resource(TokensPerSecondStatistic, '/apps/<uuid:app_id>/statistics/tokens-per-second')
api.add_resource(DailyMessageStatistic, "/apps/<uuid:app_id>/statistics/daily-messages")
api.add_resource(DailyConversationStatistic, "/apps/<uuid:app_id>/statistics/daily-conversations")
api.add_resource(DailyTerminalsStatistic, "/apps/<uuid:app_id>/statistics/daily-end-users")
api.add_resource(DailyTokenCostStatistic, "/apps/<uuid:app_id>/statistics/token-costs")
api.add_resource(AverageSessionInteractionStatistic, "/apps/<uuid:app_id>/statistics/average-session-interactions")
api.add_resource(UserSatisfactionRateStatistic, "/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
api.add_resource(AverageResponseTimeStatistic, "/apps/<uuid:app_id>/statistics/average-response-time")
api.add_resource(TokensPerSecondStatistic, "/apps/<uuid:app_id>/statistics/tokens-per-second")

View File

@@ -64,51 +64,51 @@ class DraftWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
content_type = request.headers.get('Content-Type', '')
if 'application/json' in content_type:
content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type:
parser = reqparse.RequestParser()
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
parser.add_argument('hash', type=str, required=False, location='json')
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
parser.add_argument("hash", type=str, required=False, location="json")
# TODO: set this to required=True after frontend is updated
parser.add_argument('environment_variables', type=list, required=False, location='json')
parser.add_argument('conversation_variables', type=list, required=False, location='json')
parser.add_argument("environment_variables", type=list, required=False, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json")
args = parser.parse_args()
elif 'text/plain' in content_type:
elif "text/plain" in content_type:
try:
data = json.loads(request.data.decode('utf-8'))
if 'graph' not in data or 'features' not in data:
raise ValueError('graph or features not found in data')
data = json.loads(request.data.decode("utf-8"))
if "graph" not in data or "features" not in data:
raise ValueError("graph or features not found in data")
if not isinstance(data.get('graph'), dict) or not isinstance(data.get('features'), dict):
raise ValueError('graph or features is not a dict')
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
raise ValueError("graph or features is not a dict")
args = {
'graph': data.get('graph'),
'features': data.get('features'),
'hash': data.get('hash'),
'environment_variables': data.get('environment_variables'),
'conversation_variables': data.get('conversation_variables'),
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
}
except json.JSONDecodeError:
return {'message': 'Invalid JSON data'}, 400
return {"message": "Invalid JSON data"}, 400
else:
abort(415)
workflow_service = WorkflowService()
try:
environment_variables_list = args.get('environment_variables') or []
environment_variables_list = args.get("environment_variables") or []
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
conversation_variables_list = args.get('conversation_variables') or []
conversation_variables_list = args.get("conversation_variables") or []
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
graph=args['graph'],
features=args['features'],
unique_hash=args.get('hash'),
graph=args["graph"],
features=args["features"],
unique_hash=args.get("hash"),
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
@@ -119,7 +119,7 @@ class DraftWorkflowApi(Resource):
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at)
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@@ -138,13 +138,11 @@ class DraftWorkflowImportApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
workflow = AppDslService.import_and_overwrite_workflow(
app_model=app_model,
data=args['data'],
account=current_user
app_model=app_model, data=args["data"], account=current_user
)
return workflow
@@ -162,21 +160,17 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, location='json')
parser.add_argument('query', type=str, required=True, location='json', default='')
parser.add_argument('files', type=list, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument("inputs", type=dict, location="json")
parser.add_argument("query", type=str, required=True, location="json", default="")
parser.add_argument("files", type=list, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
)
return helper.compact_generate_response(response)
@@ -190,6 +184,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
logging.exception("internal server error.")
raise InternalServerError()
class AdvancedChatDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@@ -202,18 +197,14 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, location='json')
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
app_model=app_model,
user=current_user,
node_id=node_id,
args=args,
streaming=True
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
@@ -227,6 +218,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@@ -239,18 +231,14 @@ class WorkflowDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, location='json')
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
app_model=app_model,
user=current_user,
node_id=node_id,
args=args,
streaming=True
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
@@ -264,6 +252,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
logging.exception("internal server error.")
raise InternalServerError()
class DraftWorkflowRunApi(Resource):
@setup_required
@login_required
@@ -276,19 +265,15 @@ class DraftWorkflowRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
)
return helper.compact_generate_response(response)
@@ -311,12 +296,10 @@ class WorkflowTaskStopApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
return {
"result": "success"
}
return {"result": "success"}
class DraftWorkflowNodeRunApi(Resource):
@@ -332,24 +315,20 @@ class DraftWorkflowNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
workflow_service = WorkflowService()
workflow_node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model,
node_id=node_id,
user_inputs=args.get('inputs'),
account=current_user
app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user
)
return workflow_node_execution
class PublishedWorkflowApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -362,7 +341,7 @@ class PublishedWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# fetch published workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app_model)
@@ -381,14 +360,11 @@ class PublishedWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
workflow_service = WorkflowService()
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
return {
"result": "success",
"created_at": TimestampField().format(workflow.created_at)
}
return {"result": "success", "created_at": TimestampField().format(workflow.created_at)}
class DefaultBlockConfigsApi(Resource):
@@ -403,7 +379,7 @@ class DefaultBlockConfigsApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# Get default block configs
workflow_service = WorkflowService()
return workflow_service.get_default_block_configs()
@@ -421,24 +397,21 @@ class DefaultBlockConfigApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('q', type=str, location='args')
parser.add_argument("q", type=str, location="args")
args = parser.parse_args()
filters = None
if args.get('q'):
if args.get("q"):
try:
filters = json.loads(args.get('q'))
filters = json.loads(args.get("q"))
except json.JSONDecodeError:
raise ValueError('Invalid filters')
raise ValueError("Invalid filters")
# Get default block configs
workflow_service = WorkflowService()
return workflow_service.get_default_block_config(
node_type=block_type,
filters=filters
)
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
class ConvertToWorkflowApi(Resource):
@@ -455,41 +428,43 @@ class ConvertToWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if request.data:
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json')
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
else:
args = {}
# convert to workflow mode
workflow_service = WorkflowService()
new_app_model = workflow_service.convert_to_workflow(
app_model=app_model,
account=current_user,
args=args
)
new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args)
# return app id
return {
'new_app_id': new_app_model.id,
"new_app_id": new_app_model.id,
}
api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft')
api.add_resource(DraftWorkflowImportApi, '/apps/<uuid:app_id>/workflows/draft/import')
api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run')
api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop')
api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run')
api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run')
api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish')
api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs'
'/<string:block_type>')
api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow')
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
api.add_resource(
AdvancedChatDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run"
)
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
api.add_resource(
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
)
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")

View File

@@ -22,20 +22,19 @@ class WorkflowAppLogApi(Resource):
Get workflow app logs
"""
parser = reqparse.RequestParser()
parser.add_argument('keyword', type=str, location='args')
parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args')
parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
app_model=app_model,
args=args
app_model=app_model, args=args
)
return workflow_app_log_pagination
api.add_resource(WorkflowAppLogApi, '/apps/<uuid:app_id>/workflow-app-logs')
api.add_resource(WorkflowAppLogApi, "/apps/<uuid:app_id>/workflow-app-logs")

View File

@@ -28,15 +28,12 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
Get advanced chat app workflow run list
"""
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
app_model=app_model,
args=args
)
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
return result
@@ -52,15 +49,12 @@ class WorkflowRunListApi(Resource):
Get workflow run list
"""
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_workflow_runs(
app_model=app_model,
args=args
)
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
return result
@@ -98,12 +92,10 @@ class WorkflowRunNodeExecutionListApi(Resource):
workflow_run_service = WorkflowRunService()
node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
return {
'data': node_executions
}
return {"data": node_executions}
api.add_resource(AdvancedChatAppWorkflowRunListApi, '/apps/<uuid:app_id>/advanced-chat/workflow-runs')
api.add_resource(WorkflowRunListApi, '/apps/<uuid:app_id>/workflow-runs')
api.add_resource(WorkflowRunDetailApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>')
api.add_resource(WorkflowRunNodeExecutionListApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions')
api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps/<uuid:app_id>/advanced-chat/workflow-runs")
api.add_resource(WorkflowRunListApi, "/apps/<uuid:app_id>/workflow-runs")
api.add_resource(WorkflowRunDetailApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>")
api.add_resource(WorkflowRunNodeExecutionListApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions")

View File

@@ -26,56 +26,56 @@ class WorkflowDailyRunsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = '''
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' GROUP BY date order by date'
sql_query += " GROUP BY date order by date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'runs': i.runs
})
response_data.append({"date": str(i.date), "runs": i.runs})
return jsonify({"data": response_data})
return jsonify({
'data': response_data
})
class WorkflowDailyTerminalsStatistic(Resource):
@setup_required
@@ -86,56 +86,56 @@ class WorkflowDailyTerminalsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = '''
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' GROUP BY date order by date'
sql_query += " GROUP BY date order by date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'terminal_count': i.terminal_count
})
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
return jsonify({"data": response_data})
return jsonify({
'data': response_data
})
class WorkflowDailyTokenCostStatistic(Resource):
@setup_required
@@ -146,58 +146,63 @@ class WorkflowDailyTokenCostStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = '''
sql_query = """
SELECT
date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
SUM(workflow_runs.total_tokens) as token_count
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' GROUP BY date order by date'
sql_query += " GROUP BY date order by date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'token_count': i.token_count,
})
response_data.append(
{
"date": str(i.date),
"token_count": i.token_count,
}
)
return jsonify({"data": response_data})
return jsonify({
'data': response_data
})
class WorkflowAverageAppInteractionStatistic(Resource):
@setup_required
@@ -208,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@@ -229,50 +234,54 @@ class WorkflowAverageAppInteractionStatistic(Resource):
GROUP BY date, c.created_by) sub
GROUP BY sub.date
"""
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace('{{start}}', ' AND c.created_at >= :start')
arg_dict['start'] = start_datetime_utc
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
arg_dict["start"] = start_datetime_utc
else:
sql_query = sql_query.replace('{{start}}', '')
sql_query = sql_query.replace("{{start}}", "")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace('{{end}}', ' and c.created_at < :end')
arg_dict['end'] = end_datetime_utc
sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
arg_dict["end"] = end_datetime_utc
else:
sql_query = sql_query.replace('{{end}}', '')
sql_query = sql_query.replace("{{end}}", "")
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({
'date': str(i.date),
'interactions': float(i.interactions.quantize(Decimal('0.01')))
})
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
)
return jsonify({
'data': response_data
})
return jsonify({"data": response_data})
api.add_resource(WorkflowDailyRunsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-conversations')
api.add_resource(WorkflowDailyTerminalsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-terminals')
api.add_resource(WorkflowDailyTokenCostStatistic, '/apps/<uuid:app_id>/workflow/statistics/token-costs')
api.add_resource(WorkflowAverageAppInteractionStatistic, '/apps/<uuid:app_id>/workflow/statistics/average-app-interactions')
api.add_resource(WorkflowDailyRunsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
api.add_resource(WorkflowDailyTerminalsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
api.add_resource(WorkflowDailyTokenCostStatistic, "/apps/<uuid:app_id>/workflow/statistics/token-costs")
api.add_resource(
WorkflowAverageAppInteractionStatistic, "/apps/<uuid:app_id>/workflow/statistics/average-app-interactions"
)

View File

@@ -8,24 +8,23 @@ from libs.login import current_user
from models.model import App, AppMode
def get_app_model(view: Optional[Callable] = None, *,
mode: Union[AppMode, list[AppMode]] = None):
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
if not kwargs.get('app_id'):
raise ValueError('missing app_id in path parameters')
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
app_id = kwargs.get('app_id')
app_id = kwargs.get("app_id")
app_id = str(app_id)
del kwargs['app_id']
del kwargs["app_id"]
app_model = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
app_model = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app_model:
raise AppNotFoundError()
@@ -44,9 +43,10 @@ def get_app_model(view: Optional[Callable] = None, *,
mode_values = {m.value for m in modes}
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
kwargs['app_model'] = app_model
kwargs["app_model"] = app_model
return view_func(*args, **kwargs)
return decorated_view
if view is None:

View File

@@ -17,60 +17,61 @@ from services.account_service import RegisterService
class ActivateCheckApi(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args')
parser.add_argument('email', type=email, required=False, nullable=True, location='args')
parser.add_argument('token', type=str, required=True, nullable=False, location='args')
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args")
parser.add_argument("email", type=email, required=False, nullable=True, location="args")
parser.add_argument("token", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
workspaceId = args['workspace_id']
reg_email = args['email']
token = args['token']
workspaceId = args["workspace_id"]
reg_email = args["email"]
token = args["token"]
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None}
return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None}
class ActivateApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json')
parser.add_argument('email', type=email, required=False, nullable=True, location='json')
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
parser.add_argument('interface_language', type=supported_language, required=True, nullable=False,
location='json')
parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"
)
parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
args = parser.parse_args()
invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token'])
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
if invitation is None:
raise AlreadyActivateError()
RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
account = invitation['account']
account.name = args['name']
account = invitation["account"]
account.name = args["name"]
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(args['password'], salt)
password_hashed = hash_password(args["password"], salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
account.interface_language = args['interface_language']
account.timezone = args['timezone']
account.interface_theme = 'light'
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
return {'result': 'success'}
return {"result": "success"}
api.add_resource(ActivateCheckApi, '/activate/check')
api.add_resource(ActivateApi, '/activate')
api.add_resource(ActivateCheckApi, "/activate/check")
api.add_resource(ActivateApi, "/activate")

View File

@@ -19,18 +19,19 @@ class ApiKeyAuthDataSource(Resource):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
if data_source_api_key_bindings:
return {
'sources': [{
'id': data_source_api_key_binding.id,
'category': data_source_api_key_binding.category,
'provider': data_source_api_key_binding.provider,
'disabled': data_source_api_key_binding.disabled,
'created_at': int(data_source_api_key_binding.created_at.timestamp()),
'updated_at': int(data_source_api_key_binding.updated_at.timestamp()),
}
for data_source_api_key_binding in
data_source_api_key_bindings]
"sources": [
{
"id": data_source_api_key_binding.id,
"category": data_source_api_key_binding.category,
"provider": data_source_api_key_binding.provider,
"disabled": data_source_api_key_binding.disabled,
"created_at": int(data_source_api_key_binding.created_at.timestamp()),
"updated_at": int(data_source_api_key_binding.updated_at.timestamp()),
}
for data_source_api_key_binding in data_source_api_key_bindings
]
}
return {'sources': []}
return {"sources": []}
class ApiKeyAuthDataSourceBinding(Resource):
@@ -42,16 +43,16 @@ class ApiKeyAuthDataSourceBinding(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {'result': 'success'}, 200
return {"result": "success"}, 200
class ApiKeyAuthDataSourceBindingDelete(Resource):
@@ -65,9 +66,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
return {'result': 'success'}, 200
return {"result": "success"}, 200
api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding")
api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>")

View File

@@ -17,13 +17,13 @@ from ..wraps import account_initialization_required
def get_oauth_providers():
with current_app.app_context():
notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID,
client_secret=dify_config.NOTION_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion')
notion_oauth = NotionOAuth(
client_id=dify_config.NOTION_CLIENT_ID,
client_secret=dify_config.NOTION_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion",
)
OAUTH_PROVIDERS = {
'notion': notion_oauth
}
OAUTH_PROVIDERS = {"notion": notion_oauth}
return OAUTH_PROVIDERS
@@ -37,18 +37,16 @@ class OAuthDataSource(Resource):
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
print(vars(oauth_provider))
if not oauth_provider:
return {'error': 'Invalid provider'}, 400
if dify_config.NOTION_INTEGRATION_TYPE == 'internal':
return {"error": "Invalid provider"}, 400
if dify_config.NOTION_INTEGRATION_TYPE == "internal":
internal_secret = dify_config.NOTION_INTERNAL_SECRET
if not internal_secret:
return {'error': 'Internal secret is not set'},
return ({"error": "Internal secret is not set"},)
oauth_provider.save_internal_access_token(internal_secret)
return { 'data': '' }
return {"data": ""}
else:
auth_url = oauth_provider.get_authorization_url()
return { 'data': auth_url }, 200
return {"data": auth_url}, 200
class OAuthDataSourceCallback(Resource):
@@ -57,18 +55,18 @@ class OAuthDataSourceCallback(Resource):
with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
if not oauth_provider:
return {'error': 'Invalid provider'}, 400
if 'code' in request.args:
code = request.args.get('code')
return {"error": "Invalid provider"}, 400
if "code" in request.args:
code = request.args.get("code")
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}')
elif 'error' in request.args:
error = request.args.get('error')
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}")
elif "error" in request.args:
error = request.args.get("error")
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}')
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}")
else:
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied')
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied")
class OAuthDataSourceBinding(Resource):
def get(self, provider: str):
@@ -76,17 +74,18 @@ class OAuthDataSourceBinding(Resource):
with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
if not oauth_provider:
return {'error': 'Invalid provider'}, 400
if 'code' in request.args:
code = request.args.get('code')
return {"error": "Invalid provider"}, 400
if "code" in request.args:
code = request.args.get("code")
try:
oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e:
logging.exception(
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
return {'error': 'OAuth data source process failed'}, 400
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}"
)
return {"error": "OAuth data source process failed"}, 400
return {'result': 'success'}, 200
return {"result": "success"}, 200
class OAuthDataSourceSync(Resource):
@@ -100,18 +99,17 @@ class OAuthDataSourceSync(Resource):
with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
if not oauth_provider:
return {'error': 'Invalid provider'}, 400
return {"error": "Invalid provider"}, 400
try:
oauth_provider.sync_data_source(binding_id)
except requests.exceptions.HTTPError as e:
logging.exception(
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
return {'error': 'OAuth data source process failed'}, 400
logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
return {"error": "OAuth data source process failed"}, 400
return {'result': 'success'}, 200
return {"result": "success"}, 200
api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>')
api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")

View File

@@ -2,31 +2,30 @@ from libs.exception import BaseHTTPException
class ApiKeyAuthFailedError(BaseHTTPException):
error_code = 'auth_failed'
error_code = "auth_failed"
description = "{message}"
code = 500
class InvalidEmailError(BaseHTTPException):
error_code = 'invalid_email'
error_code = "invalid_email"
description = "The email address is not valid."
code = 400
class PasswordMismatchError(BaseHTTPException):
error_code = 'password_mismatch'
error_code = "password_mismatch"
description = "The passwords do not match."
code = 400
class InvalidTokenError(BaseHTTPException):
error_code = 'invalid_or_expired_token'
error_code = "invalid_or_expired_token"
description = "The token is invalid or has expired."
code = 400
class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = 'password_reset_rate_limit_exceeded'
error_code = "password_reset_rate_limit_exceeded"
description = "Password reset rate limit exceeded. Try again later."
code = 429

View File

@@ -21,14 +21,13 @@ from services.errors.account import RateLimitExceededError
class ForgotPasswordSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('email', type=str, required=True, location='json')
parser.add_argument("email", type=str, required=True, location="json")
args = parser.parse_args()
email = args['email']
email = args["email"]
if not email_validate(email):
raise InvalidEmailError()
@@ -49,38 +48,36 @@ class ForgotPasswordSendEmailApi(Resource):
class ForgotPasswordCheckApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
token = args['token']
token = args["token"]
reset_data = AccountService.get_reset_password_data(token)
if reset_data is None:
return {'is_valid': False, 'email': None}
return {'is_valid': True, 'email': reset_data.get('email')}
return {"is_valid": False, "email": None}
return {"is_valid": True, "email": reset_data.get("email")}
class ForgotPasswordResetApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
args = parser.parse_args()
new_password = args['new_password']
password_confirm = args['password_confirm']
new_password = args["new_password"]
password_confirm = args["password_confirm"]
if str(new_password).strip() != str(password_confirm).strip():
raise PasswordMismatchError()
token = args['token']
token = args["token"]
reset_data = AccountService.get_reset_password_data(token)
if reset_data is None:
@@ -94,14 +91,14 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = Account.query.filter_by(email=reset_data.get('email')).first()
account = Account.query.filter_by(email=reset_data.get("email")).first()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
return {'result': 'success'}
return {"result": "success"}
api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")

View File

@@ -20,37 +20,39 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
parser.add_argument('email', type=email, required=True, location='json')
parser.add_argument('password', type=valid_password, required=True, location='json')
parser.add_argument('remember_me', type=bool, required=False, default=False, location='json')
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
args = parser.parse_args()
# todo: Verify the recaptcha
try:
account = AccountService.authenticate(args['email'], args['password'])
account = AccountService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError as e:
return {'code': 'unauthorized', 'message': str(e)}, 401
return {"code": "unauthorized", "message": str(e)}, 401
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}
return {
"result": "fail",
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
}
token = AccountService.login(account, ip_address=get_remote_ip(request))
return {'result': 'success', 'data': token}
return {"result": "success", "data": token}
class LogoutApi(Resource):
@setup_required
def get(self):
account = cast(Account, flask_login.current_user)
token = request.headers.get('Authorization', '').split(' ')[1]
token = request.headers.get("Authorization", "").split(" ")[1]
AccountService.logout(account=account, token=token)
flask_login.logout_user()
return {'result': 'success'}
return {"result": "success"}
class ResetPasswordApi(Resource):
@@ -80,11 +82,11 @@ class ResetPasswordApi(Resource):
# 'subject': 'Reset your Dify password',
# 'html': """
# <p>Dear User,</p>
# <p>The Dify team has generated a new password for you, details as follows:</p>
# <p>The Dify team has generated a new password for you, details as follows:</p>
# <p><strong>{new_password}</strong></p>
# <p>Please change your password to log in as soon as possible.</p>
# <p>Regards,</p>
# <p>The Dify Team</p>
# <p>The Dify Team</p>
# """
# }
@@ -101,8 +103,8 @@ class ResetPasswordApi(Resource):
# # handle error
# pass
return {'result': 'success'}
return {"result": "success"}
api.add_resource(LoginApi, '/login')
api.add_resource(LogoutApi, '/logout')
api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout")

View File

@@ -25,7 +25,7 @@ def get_oauth_providers():
github_oauth = GitHubOAuth(
client_id=dify_config.GITHUB_CLIENT_ID,
client_secret=dify_config.GITHUB_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github',
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github",
)
if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET:
google_oauth = None
@@ -33,10 +33,10 @@ def get_oauth_providers():
google_oauth = GoogleOAuth(
client_id=dify_config.GOOGLE_CLIENT_ID,
client_secret=dify_config.GOOGLE_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google',
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
)
OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth}
OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth}
return OAUTH_PROVIDERS
@@ -47,7 +47,7 @@ class OAuthLogin(Resource):
oauth_provider = OAUTH_PROVIDERS.get(provider)
print(vars(oauth_provider))
if not oauth_provider:
return {'error': 'Invalid provider'}, 400
return {"error": "Invalid provider"}, 400
auth_url = oauth_provider.get_authorization_url()
return redirect(auth_url)
@@ -59,20 +59,20 @@ class OAuthCallback(Resource):
with current_app.app_context():
oauth_provider = OAUTH_PROVIDERS.get(provider)
if not oauth_provider:
return {'error': 'Invalid provider'}, 400
return {"error": "Invalid provider"}, 400
code = request.args.get('code')
code = request.args.get("code")
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.HTTPError as e:
logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}')
return {'error': 'OAuth process failed'}, 400
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
return {"error": "OAuth process failed"}, 400
account = _generate_account(provider, user_info)
# Check account status
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
return {'error': 'Account is banned or closed.'}, 403
return {"error": "Account is banned or closed."}, 403
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
@@ -83,7 +83,7 @@ class OAuthCallback(Resource):
token = AccountService.login(account, ip_address=get_remote_ip(request))
return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}')
return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
@@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
if not account:
# Create account
account_name = user_info.name if user_info.name else 'Dify'
account_name = user_info.name if user_info.name else "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
)
@@ -121,5 +121,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
return account
api.add_resource(OAuthLogin, '/oauth/login/<provider>')
api.add_resource(OAuthCallback, '/oauth/authorize/<provider>')
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")

View File

@@ -9,28 +9,24 @@ from services.billing_service import BillingService
class Subscription(Resource):
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
args = parser.parse_args()
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args['plan'],
args['interval'],
current_user.email,
current_user.current_tenant_id)
return BillingService.get_subscription(
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
)
class Invoices(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -40,5 +36,5 @@ class Invoices(Resource):
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
api.add_resource(Subscription, '/billing/subscription')
api.add_resource(Invoices, '/billing/invoices')
api.add_resource(Subscription, "/billing/subscription")
api.add_resource(Invoices, "/billing/invoices")

View File

@@ -22,19 +22,22 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task
class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_fields)
def get(self):
# get workspace data source integrates
data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False
).all()
data_source_integrates = (
db.session.query(DataSourceOauthBinding)
.filter(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False,
)
.all()
)
base_url = request.url_root.rstrip('/')
base_url = request.url_root.rstrip("/")
data_source_oauth_base_path = "/console/api/oauth/data-source"
providers = ["notion"]
@@ -44,26 +47,30 @@ class DataSourceApi(Resource):
existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates)
if existing_integrates:
for existing_integrate in list(existing_integrates):
integrate_data.append({
'id': existing_integrate.id,
'provider': provider,
'created_at': existing_integrate.created_at,
'is_bound': True,
'disabled': existing_integrate.disabled,
'source_info': existing_integrate.source_info,
'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
})
integrate_data.append(
{
"id": existing_integrate.id,
"provider": provider,
"created_at": existing_integrate.created_at,
"is_bound": True,
"disabled": existing_integrate.disabled,
"source_info": existing_integrate.source_info,
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
}
)
else:
integrate_data.append({
'id': None,
'provider': provider,
'created_at': None,
'source_info': None,
'is_bound': False,
'disabled': None,
'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
})
return {'data': integrate_data}, 200
integrate_data.append(
{
"id": None,
"provider": provider,
"created_at": None,
"source_info": None,
"is_bound": False,
"disabled": None,
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
}
)
return {"data": integrate_data}, 200
@setup_required
@login_required
@@ -71,92 +78,82 @@ class DataSourceApi(Resource):
def patch(self, binding_id, action):
binding_id = str(binding_id)
action = str(action)
data_source_binding = DataSourceOauthBinding.query.filter_by(
id=binding_id
).first()
data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
if data_source_binding is None:
raise NotFound('Data source binding not found.')
raise NotFound("Data source binding not found.")
# enable binding
if action == 'enable':
if action == "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError('Data source is not disabled.')
raise ValueError("Data source is not disabled.")
# disable binding
if action == 'disable':
if action == "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError('Data source is disabled.')
return {'result': 'success'}, 200
raise ValueError("Data source is disabled.")
return {"result": "success"}, 200
class DataSourceNotionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_notion_info_list_fields)
def get(self):
dataset_id = request.args.get('dataset_id', default=None, type=str)
dataset_id = request.args.get("dataset_id", default=None, type=str)
exist_page_ids = []
# import notion in the exist dataset
if dataset_id:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
if dataset.data_source_type != 'notion_import':
raise ValueError('Dataset is not notion type.')
raise NotFound("Dataset not found.")
if dataset.data_source_type != "notion_import":
raise ValueError("Dataset is not notion type.")
documents = Document.query.filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
data_source_type='notion_import',
enabled=True
data_source_type="notion_import",
enabled=True,
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info['notion_page_id'])
exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
data_source_bindings = DataSourceOauthBinding.query.filter_by(
tenant_id=current_user.current_tenant_id,
provider='notion',
disabled=False
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
).all()
if not data_source_bindings:
return {
'notion_info': []
}, 200
return {"notion_info": []}, 200
pre_import_info_list = []
for data_source_binding in data_source_bindings:
source_info = data_source_binding.source_info
pages = source_info['pages']
pages = source_info["pages"]
# Filter out already bound pages
for page in pages:
if page['page_id'] in exist_page_ids:
page['is_bound'] = True
if page["page_id"] in exist_page_ids:
page["is_bound"] = True
else:
page['is_bound'] = False
page["is_bound"] = False
pre_import_info = {
'workspace_name': source_info['workspace_name'],
'workspace_icon': source_info['workspace_icon'],
'workspace_id': source_info['workspace_id'],
'pages': pages,
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
}
pre_import_info_list.append(pre_import_info)
return {
'notion_info': pre_import_info_list
}, 200
return {"notion_info": pre_import_info_list}, 200
class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -166,64 +163,67 @@ class DataSourceNotionApi(Resource):
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
).first()
if not data_source_binding:
raise NotFound('Data source binding not found.')
raise NotFound("Data source binding not found.")
extractor = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=data_source_binding.access_token,
tenant_id=current_user.current_tenant_id
tenant_id=current_user.current_tenant_id,
)
text_docs = extractor.extract()
return {
'content': "\n".join([doc.page_content for doc in text_docs])
}, 200
return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
notion_info_list = args['notion_info_list']
notion_info_list = args["notion_info_list"]
extract_settings = []
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
for page in notion_info['pages']:
workspace_id = notion_info["workspace_id"]
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page['page_id'],
"notion_page_type": page['type'],
"tenant_id": current_user.current_tenant_id
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
},
document_model=args['doc_form']
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
args['process_rule'], args['doc_form'],
args['doc_language'])
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
extract_settings,
args["process_rule"],
args["doc_form"],
args["doc_language"],
)
return response, 200
class DataSourceNotionDatasetSyncApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -240,7 +240,6 @@ class DataSourceNotionDatasetSyncApi(Resource):
class DataSourceNotionDocumentSyncApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -258,10 +257,14 @@ class DataSourceNotionDocumentSyncApi(Resource):
return 200
api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
api.add_resource(DataSourceNotionApi,
'/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview',
'/datasets/notion-indexing-estimate')
api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets/<uuid:dataset_id>/notion/sync')
api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync')
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
api.add_resource(
DataSourceNotionApi,
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
api.add_resource(
DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
)

View File

@@ -31,45 +31,40 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError('Name must be between 1 to 40 characters.')
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError('Description cannot exceed 400 characters.')
raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
ids = request.args.getlist('ids')
provider = request.args.get('provider', default="vendor")
search = request.args.get('keyword', default=None, type=str)
tag_ids = request.args.getlist('tag_ids')
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(page, limit, provider,
current_user.current_tenant_id, current_user, search, tag_ids)
datasets, total = DatasetService.get_datasets(
page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids
)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
model_names = []
for embedding_model in embedding_models:
@@ -77,28 +72,22 @@ class DatasetListApi(Resource):
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item['indexing_technique'] == 'high_quality':
if item["indexing_technique"] == "high_quality":
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item['embedding_available'] = True
item["embedding_available"] = True
else:
item['embedding_available'] = False
item["embedding_available"] = False
else:
item['embedding_available'] = True
item["embedding_available"] = True
if item.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
item.update({'partial_member_list': part_users_list})
if item.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"])
item.update({"partial_member_list": part_users_list})
else:
item.update({'partial_member_list': []})
item.update({"partial_member_list": []})
response = {
'data': data,
'has_more': len(datasets) == limit,
'limit': limit,
'total': total,
'page': page
}
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200
@setup_required
@@ -106,13 +95,21 @@ class DatasetListApi(Resource):
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False, required=True,
help='type is required. Name must be between 1 to 40 characters.',
type=_validate_name)
parser.add_argument('indexing_technique', type=str, location='json',
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help='Invalid indexing technique.')
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -122,9 +119,10 @@ class DatasetListApi(Resource):
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id,
name=args['name'],
indexing_technique=args['indexing_technique'],
account=current_user
name=args["name"],
indexing_technique=args["indexing_technique"],
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@@ -142,42 +140,36 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(
dataset, current_user)
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get('permission') == 'partial_members':
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({'partial_member_list': part_users_list})
data.update({"partial_member_list": part_users_list})
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
model_names = []
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data['indexing_technique'] == 'high_quality':
if data["indexing_technique"] == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data['embedding_available'] = True
data["embedding_available"] = True
else:
data['embedding_available'] = False
data["embedding_available"] = False
else:
data['embedding_available'] = True
data["embedding_available"] = True
if data.get('permission') == 'partial_members':
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({'partial_member_list': part_users_list})
data.update({"partial_member_list": part_users_list})
return data, 200
@@ -191,42 +183,49 @@ class DatasetApi(Resource):
raise NotFound("Dataset not found.")
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False,
help='type is required. Name must be between 1 to 40 characters.',
type=_validate_name)
parser.add_argument('description',
location='json', store_missing=False,
type=_validate_description_length)
parser.add_argument('indexing_technique', type=str, location='json',
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.'
)
parser.add_argument('embedding_model', type=str,
location='json', help='Invalid embedding model.')
parser.add_argument('embedding_model_provider', type=str,
location='json', help='Invalid embedding model provider.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
parser.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
parser.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
args = parser.parse_args()
data = request.get_json()
# check embedding model setting
if data.get('indexing_technique') == 'high_quality':
DatasetService.check_embedding_model_setting(dataset.tenant_id,
data.get('embedding_model_provider'),
data.get('embedding_model')
)
if data.get("indexing_technique") == "high_quality":
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get('permission'), data.get('partial_member_list')
current_user, dataset, data.get("permission"), data.get("partial_member_list")
)
dataset = DatasetService.update_dataset(
dataset_id_str, args, current_user)
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
@@ -234,16 +233,19 @@ class DatasetApi(Resource):
result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id
if data.get('partial_member_list') and data.get('permission') == 'partial_members':
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get('partial_member_list')
tenant_id, dataset_id_str, data.get("partial_member_list")
)
# clear partial member list when permission is only_me or all_team_members
elif data.get('permission') == DatasetPermissionEnum.ONLY_ME or data.get('permission') == DatasetPermissionEnum.ALL_TEAM:
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({'partial_member_list': partial_member_list})
result_data.update({"partial_member_list": partial_member_list})
return result_data, 200
@@ -260,12 +262,13 @@ class DatasetApi(Resource):
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {'result': 'success'}, 204
return {"result": "success"}, 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:
raise DatasetInUseError()
class DatasetUseCheckApi(Resource):
@setup_required
@login_required
@@ -274,10 +277,10 @@ class DatasetUseCheckApi(Resource):
dataset_id_str = str(dataset_id)
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
return {'is_using': dataset_is_using}, 200
return {"is_using": dataset_is_using}, 200
class DatasetQueryApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -292,51 +295,53 @@ class DatasetQueryApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
dataset_queries, total = DatasetService.get_dataset_queries(
dataset_id=dataset.id,
page=page,
per_page=limit
)
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
response = {
'data': marshal(dataset_queries, dataset_query_detail_fields),
'has_more': len(dataset_queries) == limit,
'limit': limit,
'total': total,
'page': page
"data": marshal(dataset_queries, dataset_query_detail_fields),
"has_more": len(dataset_queries) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
class DatasetIndexingEstimateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('indexing_technique', type=str, required=True,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
parser.add_argument(
"indexing_technique",
type=str,
required=True,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
location="json",
)
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
extract_settings = []
if args['info_list']['data_source_type'] == 'upload_file':
file_ids = args['info_list']['file_info_list']['file_ids']
file_details = db.session.query(UploadFile).filter(
UploadFile.tenant_id == current_user.current_tenant_id,
UploadFile.id.in_(file_ids)
).all()
if args["info_list"]["data_source_type"] == "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
.all()
)
if file_details is None:
raise NotFound("File not found.")
@@ -344,55 +349,58 @@ class DatasetIndexingEstimateApi(Resource):
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=file_detail,
document_model=args['doc_form']
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
)
extract_settings.append(extract_setting)
elif args['info_list']['data_source_type'] == 'notion_import':
notion_info_list = args['info_list']['notion_info_list']
elif args["info_list"]["data_source_type"] == "notion_import":
notion_info_list = args["info_list"]["notion_info_list"]
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
for page in notion_info['pages']:
workspace_id = notion_info["workspace_id"]
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page['page_id'],
"notion_page_type": page['type'],
"tenant_id": current_user.current_tenant_id
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
},
document_model=args['doc_form']
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
elif args['info_list']['data_source_type'] == 'website_crawl':
website_info_list = args['info_list']['website_info_list']
for url in website_info_list['urls']:
elif args["info_list"]["data_source_type"] == "website_crawl":
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type="website_crawl",
website_info={
"provider": website_info_list['provider'],
"job_id": website_info_list['job_id'],
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
"mode": 'crawl',
"only_main_content": website_info_list['only_main_content']
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
},
document_model=args['doc_form']
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
else:
raise ValueError('Data source type not support')
raise ValueError("Data source type not support")
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'],
args['indexing_technique'])
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
extract_settings,
args["process_rule"],
args["doc_form"],
args["doc_language"],
args["dataset_id"],
args["indexing_technique"],
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except Exception as e:
@@ -402,7 +410,6 @@ class DatasetIndexingEstimateApi(Resource):
class DatasetRelatedAppListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -426,52 +433,52 @@ class DatasetRelatedAppListApi(Resource):
if app_model:
related_apps.append(app_model)
return {
'data': related_apps,
'total': len(related_apps)
}, 200
return {"data": related_apps, "total": len(related_apps)}, 200
class DatasetIndexingStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id = str(dataset_id)
documents = db.session.query(Document).filter(
Document.dataset_id == dataset_id,
Document.tenant_id == current_user.current_tenant_id
).all()
documents = (
db.session.query(Document)
.filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
.all()
)
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != 're_segment').count()
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
DocumentSegment.status != 're_segment').count()
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
document.completed_segments = completed_segments
document.total_segments = total_segments
documents_status.append(marshal(document, document_status_fields))
data = {
'data': documents_status
}
data = {"data": documents_status}
return data
class DatasetApiKeyApi(Resource):
max_keys = 10
token_prefix = 'dataset-'
resource_type = 'dataset'
token_prefix = "dataset-"
resource_type = "dataset"
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_list)
def get(self):
keys = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
all()
keys = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.all()
)
return {"items": keys}
@setup_required
@@ -483,15 +490,17 @@ class DatasetApiKeyApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
current_key_count = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
count()
current_key_count = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.count()
)
if current_key_count >= self.max_keys:
flask_restful.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code='max_keys_exceeded'
code="max_keys_exceeded",
)
key = ApiToken.generate_api_key(self.token_prefix, 24)
@@ -505,7 +514,7 @@ class DatasetApiKeyApi(Resource):
class DatasetApiDeleteApi(Resource):
resource_type = 'dataset'
resource_type = "dataset"
@setup_required
@login_required
@@ -517,18 +526,23 @@ class DatasetApiDeleteApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
key = db.session.query(ApiToken). \
filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id). \
first()
key = (
db.session.query(ApiToken)
.filter(
ApiToken.tenant_id == current_user.current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.first()
)
if key is None:
flask_restful.abort(404, message='API key not found')
flask_restful.abort(404, message="API key not found")
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit()
return {'result': 'success'}, 204
return {"result": "success"}, 204
class DatasetApiBaseUrlApi(Resource):
@@ -537,8 +551,10 @@ class DatasetApiBaseUrlApi(Resource):
@account_initialization_required
def get(self):
return {
'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL
else request.host_url.rstrip('/')) + '/v1'
"api_base_url": (
dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
)
+ "/v1"
}
@@ -549,15 +565,26 @@ class DatasetRetrievalSettingApi(Resource):
def get(self):
vector_type = dify_config.VECTOR_STORE
match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
case (
VectorType.MILVUS
| VectorType.RELYT
| VectorType.PGVECTOR
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
):
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
return {
'retrieval_method': [
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
@@ -573,15 +600,27 @@ class DatasetRetrievalSettingMockApi(Resource):
@account_initialization_required
def get(self, vector_type):
match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS:
case (
VectorType.MILVUS
| VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.PGVECTOR
):
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR:
return {
'retrieval_method': [
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
@@ -591,7 +630,6 @@ class DatasetRetrievalSettingMockApi(Resource):
raise ValueError(f"Unsupported vector db type {vector_type}.")
class DatasetErrorDocs(Resource):
@setup_required
@login_required
@@ -603,10 +641,7 @@ class DatasetErrorDocs(Resource):
raise NotFound("Dataset not found.")
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
return {
'data': [marshal(item, document_status_fields) for item in results],
'total': len(results)
}, 200
return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
class DatasetPermissionUserListApi(Resource):
@@ -626,21 +661,21 @@ class DatasetPermissionUserListApi(Resource):
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
return {
'data': partial_members_list,
"data": partial_members_list,
}, 200
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')
api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
api.add_resource(DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries")
api.add_resource(DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs")
api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate")
api.add_resource(DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps")
api.add_resource(DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status")
api.add_resource(DatasetApiKeyApi, "/datasets/api-keys")
api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>")
api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")

File diff suppressed because it is too large Load Diff

View File

@@ -40,7 +40,7 @@ class DatasetDocumentSegmentListApi(Resource):
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
@@ -50,37 +50,33 @@ class DatasetDocumentSegmentListApi(Resource):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
raise NotFound("Document not found.")
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=str, default=None, location='args')
parser.add_argument('limit', type=int, default=20, location='args')
parser.add_argument('status', type=str,
action='append', default=[], location='args')
parser.add_argument('hit_count_gte', type=int,
default=None, location='args')
parser.add_argument('enabled', type=str, default='all', location='args')
parser.add_argument('keyword', type=str, default=None, location='args')
parser.add_argument("last_id", type=str, default=None, location="args")
parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("status", type=str, action="append", default=[], location="args")
parser.add_argument("hit_count_gte", type=int, default=None, location="args")
parser.add_argument("enabled", type=str, default="all", location="args")
parser.add_argument("keyword", type=str, default=None, location="args")
args = parser.parse_args()
last_id = args['last_id']
limit = min(args['limit'], 100)
status_list = args['status']
hit_count_gte = args['hit_count_gte']
keyword = args['keyword']
last_id = args["last_id"]
limit = min(args["limit"], 100)
status_list = args["status"]
hit_count_gte = args["hit_count_gte"]
keyword = args["keyword"]
query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
)
if last_id is not None:
last_segment = db.session.get(DocumentSegment, str(last_id))
if last_segment:
query = query.filter(
DocumentSegment.position > last_segment.position)
query = query.filter(DocumentSegment.position > last_segment.position)
else:
return {'data': [], 'has_more': False, 'limit': limit}, 200
return {"data": [], "has_more": False, "limit": limit}, 200
if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
@@ -89,12 +85,12 @@ class DatasetDocumentSegmentListApi(Resource):
query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
if keyword:
query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
if args['enabled'].lower() != 'all':
if args['enabled'].lower() == 'true':
if args["enabled"].lower() != "all":
if args["enabled"].lower() == "true":
query = query.filter(DocumentSegment.enabled == True)
elif args['enabled'].lower() == 'false':
elif args["enabled"].lower() == "false":
query = query.filter(DocumentSegment.enabled == False)
total = query.count()
@@ -106,11 +102,11 @@ class DatasetDocumentSegmentListApi(Resource):
segments = segments[:-1]
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form,
'has_more': has_more,
'limit': limit,
'total': total
"data": marshal(segments, segment_fields),
"doc_form": document.doc_form,
"has_more": has_more,
"limit": limit,
"total": total,
}, 200
@@ -118,12 +114,12 @@ class DatasetDocumentSegmentApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('vector_space')
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, segment_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# The role of the current user in the ta table must be admin, owner, or editor
@@ -134,7 +130,7 @@ class DatasetDocumentSegmentApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
if dataset.indexing_technique == 'high_quality':
if dataset.indexing_technique == "high_quality":
# check embedding model setting
try:
model_manager = ModelManager()
@@ -142,32 +138,32 @@ class DatasetDocumentSegmentApi(Resource):
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
raise NotFound("Segment not found.")
if segment.status != 'completed':
raise NotFound('Segment is not completed, enable or disable function is not allowed')
if segment.status != "completed":
raise NotFound("Segment is not completed, enable or disable function is not allowed")
document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
cache_result = redis_client.get(document_indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later")
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
indexing_cache_key = "segment_{}_indexing".format(segment.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Segment is being indexed, please try again later")
@@ -186,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):
enable_segment_to_index_task.delay(segment.id)
return {'result': 'success'}, 200
return {"result": "success"}, 200
elif action == "disable":
if not segment.enabled:
raise InvalidActionError("Segment is already disabled.")
@@ -201,7 +197,7 @@ class DatasetDocumentSegmentApi(Resource):
disable_segment_from_index_task.delay(segment.id)
return {'result': 'success'}, 200
return {"result": "success"}, 200
else:
raise InvalidActionError()
@@ -210,35 +206,36 @@ class DatasetDocumentSegmentAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('vector_space')
@cloud_edition_billing_knowledge_limit_check('add_segment')
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
raise NotFound("Document not found.")
if not current_user.is_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
@@ -247,37 +244,34 @@ class DatasetDocumentSegmentAddApi(Resource):
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document, dataset)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
class DatasetDocumentSegmentUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('vector_space')
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
if dataset.indexing_technique == 'high_quality':
raise NotFound("Document not found.")
if dataset.indexing_technique == "high_quality":
# check embedding model setting
try:
model_manager = ModelManager()
@@ -285,22 +279,22 @@ class DatasetDocumentSegmentUpdateApi(Resource):
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
@@ -310,16 +304,13 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document, dataset)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required
@login_required
@@ -329,22 +320,21 @@ class DatasetDocumentSegmentUpdateApi(Resource):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
@@ -353,36 +343,36 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segment(segment, document, dataset)
return {'result': 'success'}, 200
return {"result": "success"}, 200
class DatasetDocumentSegmentBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('vector_space')
@cloud_edition_billing_knowledge_limit_check('add_segment')
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
raise NotFound("Document not found.")
# get file from request
file = request.files['file']
file = request.files["file"]
# check file
if 'file' not in request.files:
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename.endswith('.csv'):
if not file.filename.endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
try:
@@ -390,51 +380,47 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
df = pd.read_csv(file)
result = []
for index, row in df.iterrows():
if document.doc_form == 'qa_model':
data = {'content': row[0], 'answer': row[1]}
if document.doc_form == "qa_model":
data = {"content": row[0], "answer": row[1]}
else:
data = {'content': row[0]}
data = {"content": row[0]}
result.append(data)
if len(result) == 0:
raise ValueError("The CSV file is empty.")
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
indexing_cache_key = "segment_batch_import_{}".format(str(job_id))
# send batch add segments task
redis_client.setnx(indexing_cache_key, 'waiting')
batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
current_user.current_tenant_id, current_user.id)
redis_client.setnx(indexing_cache_key, "waiting")
batch_create_segment_to_index_task.delay(
str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id
)
except Exception as e:
return {'error': str(e)}, 500
return {
'job_id': job_id,
'job_status': 'waiting'
}, 200
return {"error": str(e)}, 500
return {"job_id": job_id, "job_status": "waiting"}, 200
@setup_required
@login_required
@account_initialization_required
def get(self, job_id):
job_id = str(job_id)
indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
indexing_cache_key = "segment_batch_import_{}".format(job_id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
raise ValueError("The job is not exist.")
return {
'job_id': job_id,
'job_status': cache_result.decode()
}, 200
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
api.add_resource(DatasetDocumentSegmentListApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi,
'/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
api.add_resource(DatasetDocumentSegmentAddApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
api.add_resource(DatasetDocumentSegmentBatchImportApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
'/datasets/batch_import_status/<uuid:job_id>')
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>")
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
api.add_resource(
DatasetDocumentSegmentUpdateApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
)
api.add_resource(
DatasetDocumentSegmentBatchImportApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
"/datasets/batch_import_status/<uuid:job_id>",
)

View File

@@ -2,90 +2,90 @@ from libs.exception import BaseHTTPException
class NoFileUploadedError(BaseHTTPException):
error_code = 'no_file_uploaded'
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files'
error_code = "too_many_files"
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = 'file_too_large'
error_code = "file_too_large"
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = 'unsupported_file_type'
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415
class HighQualityDatasetOnlyError(BaseHTTPException):
error_code = 'high_quality_dataset_only'
error_code = "high_quality_dataset_only"
description = "Current operation only supports 'high-quality' datasets."
code = 400
class DatasetNotInitializedError(BaseHTTPException):
error_code = 'dataset_not_initialized'
error_code = "dataset_not_initialized"
description = "The dataset is still being initialized or indexing. Please wait a moment."
code = 400
class ArchivedDocumentImmutableError(BaseHTTPException):
error_code = 'archived_document_immutable'
error_code = "archived_document_immutable"
description = "The archived document is not editable."
code = 403
class DatasetNameDuplicateError(BaseHTTPException):
error_code = 'dataset_name_duplicate'
error_code = "dataset_name_duplicate"
description = "The dataset name already exists. Please modify your dataset name."
code = 409
class InvalidActionError(BaseHTTPException):
error_code = 'invalid_action'
error_code = "invalid_action"
description = "Invalid action."
code = 400
class DocumentAlreadyFinishedError(BaseHTTPException):
error_code = 'document_already_finished'
error_code = "document_already_finished"
description = "The document has been processed. Please refresh the page or go to the document details."
code = 400
class DocumentIndexingError(BaseHTTPException):
error_code = 'document_indexing'
error_code = "document_indexing"
description = "The document is being processed and cannot be edited."
code = 400
class InvalidMetadataError(BaseHTTPException):
error_code = 'invalid_metadata'
error_code = "invalid_metadata"
description = "The metadata content is incorrect. Please check and verify."
code = 400
class WebsiteCrawlError(BaseHTTPException):
error_code = 'crawl_failed'
error_code = "crawl_failed"
description = "{message}"
code = 500
class DatasetInUseError(BaseHTTPException):
error_code = 'dataset_in_use'
error_code = "dataset_in_use"
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
code = 409
class IndexingEstimateError(BaseHTTPException):
error_code = 'indexing_estimate_error'
error_code = "indexing_estimate_error"
description = "Knowledge indexing estimate failed: {message}"
code = 500

View File

@@ -21,7 +21,6 @@ PREVIEW_WORDS_LIMIT = 3000
class FileApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -31,23 +30,22 @@ class FileApi(Resource):
batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
return {
'file_size_limit': file_size_limit,
'batch_count_limit': batch_count_limit,
'image_file_size_limit': image_file_size_limit
"file_size_limit": file_size_limit,
"batch_count_limit": batch_count_limit,
"image_file_size_limit": image_file_size_limit,
}, 200
@setup_required
@login_required
@account_initialization_required
@marshal_with(file_fields)
@cloud_edition_billing_resource_check(resource='documents')
@cloud_edition_billing_resource_check("documents")
def post(self):
# get file from request
file = request.files['file']
file = request.files["file"]
# check file
if 'file' not in request.files:
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
@@ -69,7 +67,7 @@ class FilePreviewApi(Resource):
def get(self, file_id):
file_id = str(file_id)
text = FileService.get_file_preview(file_id)
return {'content': text}
return {"content": text}
class FileSupportTypeApi(Resource):
@@ -78,10 +76,10 @@ class FileSupportTypeApi(Resource):
@account_initialization_required
def get(self):
etl_type = dify_config.ETL_TYPE
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
return {'allowed_extensions': allowed_extensions}
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
return {"allowed_extensions": allowed_extensions}
api.add_resource(FileApi, '/files/upload')
api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
api.add_resource(FileSupportTypeApi, '/files/support-type')
api.add_resource(FileApi, "/files/upload")
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
api.add_resource(FileSupportTypeApi, "/files/support-type")

View File

@@ -29,7 +29,6 @@ from services.hit_testing_service import HitTestingService
class HitTestingApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -46,8 +45,8 @@ class HitTestingApi(Resource):
raise Forbidden(str(e))
parser = reqparse.RequestParser()
parser.add_argument('query', type=str, location='json')
parser.add_argument('retrieval_model', type=dict, required=False, location='json')
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
@@ -55,13 +54,13 @@ class HitTestingApi(Resource):
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args['query'],
query=args["query"],
account=current_user,
retrieval_model=args['retrieval_model'],
limit=10
retrieval_model=args["retrieval_model"],
limit=10,
)
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex:
@@ -73,7 +72,8 @@ class HitTestingApi(Resource):
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model or Reranking Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
"in the Settings -> Model Provider."
)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
@@ -83,4 +83,4 @@ class HitTestingApi(Resource):
raise InternalServerError(str(e))
api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")

View File

@@ -9,16 +9,14 @@ from services.website_service import WebsiteService
class WebsiteCrawlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, choices=['firecrawl'],
required=True, nullable=True, location='json')
parser.add_argument('url', type=str, required=True, nullable=True, location='json')
parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json")
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args()
WebsiteService.document_create_args_validate(args)
# crawl url
@@ -35,15 +33,15 @@ class WebsiteCrawlStatusApi(Resource):
@account_initialization_required
def get(self, job_id: str):
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args")
args = parser.parse_args()
# get crawl status
try:
result = WebsiteService.get_crawl_status(job_id, args['provider'])
result = WebsiteService.get_crawl_status(job_id, args["provider"])
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200
api.add_resource(WebsiteCrawlApi, '/website/crawl')
api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')
api.add_resource(WebsiteCrawlApi, "/website/crawl")
api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/<string:job_id>")

View File

@@ -2,35 +2,41 @@ from libs.exception import BaseHTTPException
class AlreadySetupError(BaseHTTPException):
error_code = 'already_setup'
error_code = "already_setup"
description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage."
code = 403
class NotSetupError(BaseHTTPException):
error_code = 'not_setup'
description = "Dify has not been initialized and installed yet. " \
"Please proceed with the initialization and installation process first."
error_code = "not_setup"
description = (
"Dify has not been initialized and installed yet. "
"Please proceed with the initialization and installation process first."
)
code = 401
class NotInitValidateError(BaseHTTPException):
error_code = 'not_init_validated'
description = "Init validation has not been completed yet. " \
"Please proceed with the init validation process first."
error_code = "not_init_validated"
description = (
"Init validation has not been completed yet. " "Please proceed with the init validation process first."
)
code = 401
class InitValidateFailedError(BaseHTTPException):
error_code = 'init_validate_failed'
error_code = "init_validate_failed"
description = "Init validation failed. Please check the password and try again."
code = 401
class AccountNotLinkTenantError(BaseHTTPException):
error_code = 'account_not_link_tenant'
error_code = "account_not_link_tenant"
description = "Account not link tenant."
code = 403
class AlreadyActivateError(BaseHTTPException):
error_code = 'already_activate'
error_code = "already_activate"
description = "Auth Token is invalid or account already activated, please check again."
code = 403

View File

@@ -33,14 +33,10 @@ class ChatAudioApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
file = request.files['file']
file = request.files["file"]
try:
response = AudioService.transcript_asr(
app_model=app_model,
file=file,
end_user=None
)
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
@@ -76,30 +72,31 @@ class ChatTextApi(InstalledAppResource):
app_model = installed_app.app
try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('text', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = args.get('message_id', None)
text = args.get('text', None)
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
else:
try:
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
message_id=message_id,
voice=voice,
text=text
)
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
@@ -127,7 +124,7 @@ class ChatTextApi(InstalledAppResource):
raise InternalServerError()
api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
# endpoint='installed_app_text_with_message_id')

View File

@@ -30,33 +30,28 @@ from services.app_generate_service import AppGenerateService
# define completion api for user
class CompletionApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
if app_model.mode != 'completion':
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
args['auto_generate_name'] = False
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
try:
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
)
return helper.compact_generate_response(response)
@@ -85,12 +80,12 @@ class CompletionApi(InstalledAppResource):
class CompletionStopApi(InstalledAppResource):
def post(self, installed_app, task_id):
app_model = installed_app.app
if app_model.mode != 'completion':
if app_model.mode != "completion":
raise NotCompletionAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {'result': 'success'}, 200
return {"result": "success"}, 200
class ChatApi(InstalledAppResource):
@@ -101,25 +96,21 @@ class ChatApi(InstalledAppResource):
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
args['auto_generate_name'] = False
args["auto_generate_name"] = False
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
try:
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.EXPLORE,
streaming=True
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
return helper.compact_generate_response(response)
@@ -154,10 +145,22 @@ class ChatStopApi(InstalledAppResource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {'result': 'success'}, 200
return {"result": "success"}, 200
api.add_resource(CompletionApi, '/installed-apps/<uuid:installed_app_id>/completion-messages', endpoint='installed_app_completion')
api.add_resource(CompletionStopApi, '/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop', endpoint='installed_app_stop_completion')
api.add_resource(ChatApi, '/installed-apps/<uuid:installed_app_id>/chat-messages', endpoint='installed_app_chat_completion')
api.add_resource(ChatStopApi, '/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop', endpoint='installed_app_stop_chat_completion')
api.add_resource(
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
)
api.add_resource(
CompletionStopApi,
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
endpoint="installed_app_stop_completion",
)
api.add_resource(
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
)
api.add_resource(
ChatStopApi,
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
endpoint="installed_app_stop_chat_completion",
)

View File

@@ -16,7 +16,6 @@ from services.web_conversation_service import WebConversationService
class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, installed_app):
app_model = installed_app.app
@@ -25,21 +24,21 @@ class ConversationListApi(InstalledAppResource):
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
args = parser.parse_args()
pinned = None
if 'pinned' in args and args['pinned'] is not None:
pinned = True if args['pinned'] == 'true' else False
if "pinned" in args and args["pinned"] is not None:
pinned = True if args["pinned"] == "true" else False
try:
return WebConversationService.pagination_by_last_id(
app_model=app_model,
user=current_user,
last_id=args['last_id'],
limit=args['limit'],
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.EXPLORE,
pinned=pinned,
)
@@ -65,7 +64,6 @@ class ConversationApi(InstalledAppResource):
class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields)
def post(self, installed_app, c_id):
app_model = installed_app.app
@@ -76,24 +74,19 @@ class ConversationRenameApi(InstalledAppResource):
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
parser.add_argument("name", type=str, required=False, location="json")
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
args = parser.parse_args()
try:
return ConversationService.rename(
app_model,
conversation_id,
current_user,
args['name'],
args['auto_generate']
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
class ConversationPinApi(InstalledAppResource):
def patch(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
@@ -123,8 +116,26 @@ class ConversationUnPinApi(InstalledAppResource):
return {"result": "success"}
api.add_resource(ConversationRenameApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name', endpoint='installed_app_conversation_rename')
api.add_resource(ConversationListApi, '/installed-apps/<uuid:installed_app_id>/conversations', endpoint='installed_app_conversations')
api.add_resource(ConversationApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>', endpoint='installed_app_conversation')
api.add_resource(ConversationPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin', endpoint='installed_app_conversation_pin')
api.add_resource(ConversationUnPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin', endpoint='installed_app_conversation_unpin')
api.add_resource(
ConversationRenameApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
endpoint="installed_app_conversation_rename",
)
api.add_resource(
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
)
api.add_resource(
ConversationApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
endpoint="installed_app_conversation",
)
api.add_resource(
ConversationPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
endpoint="installed_app_conversation_pin",
)
api.add_resource(
ConversationUnPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
endpoint="installed_app_conversation_unpin",
)

View File

@@ -2,24 +2,24 @@ from libs.exception import BaseHTTPException
class NotCompletionAppError(BaseHTTPException):
error_code = 'not_completion_app'
error_code = "not_completion_app"
description = "Not Completion App"
code = 400
class NotChatAppError(BaseHTTPException):
error_code = 'not_chat_app'
error_code = "not_chat_app"
description = "App mode is invalid."
code = 400
class NotWorkflowAppError(BaseHTTPException):
error_code = 'not_workflow_app'
error_code = "not_workflow_app"
description = "Only support workflow app."
code = 400
class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
error_code = 'app_suggested_questions_after_answer_disabled'
error_code = "app_suggested_questions_after_answer_disabled"
description = "Function Suggested questions after answer disabled."
code = 403

View File

@@ -21,72 +21,72 @@ class InstalledAppsListApi(Resource):
@marshal_with(installed_app_list_fields)
def get(self):
current_tenant_id = current_user.current_tenant_id
installed_apps = db.session.query(InstalledApp).filter(
InstalledApp.tenant_id == current_tenant_id
).all()
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_apps = [
{
'id': installed_app.id,
'app': installed_app.app,
'app_owner_tenant_id': installed_app.app_owner_tenant_id,
'is_pinned': installed_app.is_pinned,
'last_used_at': installed_app.last_used_at,
'editable': current_user.role in ["owner", "admin"],
'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id
"id": installed_app.id,
"app": installed_app.app,
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
"is_pinned": installed_app.is_pinned,
"last_used_at": installed_app.last_used_at,
"editable": current_user.role in ["owner", "admin"],
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
}
for installed_app in installed_apps
if installed_app.app is not None
]
installed_apps.sort(key=lambda app: (-app['is_pinned'],
app['last_used_at'] is None,
-app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))
installed_apps.sort(
key=lambda app: (
-app["is_pinned"],
app["last_used_at"] is None,
-app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0,
)
)
return {'installed_apps': installed_apps}
return {"installed_apps": installed_apps}
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('apps')
@cloud_edition_billing_resource_check("apps")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('app_id', type=str, required=True, help='Invalid app_id')
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args()
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
if recommended_app is None:
raise NotFound('App not found')
raise NotFound("App not found")
current_tenant_id = current_user.current_tenant_id
app = db.session.query(App).filter(
App.id == args['app_id']
).first()
app = db.session.query(App).filter(App.id == args["app_id"]).first()
if app is None:
raise NotFound('App not found')
raise NotFound("App not found")
if not app.is_public:
raise Forbidden('You can\'t install a non-public app')
raise Forbidden("You can't install a non-public app")
installed_app = InstalledApp.query.filter(and_(
InstalledApp.app_id == args['app_id'],
InstalledApp.tenant_id == current_tenant_id
)).first()
installed_app = InstalledApp.query.filter(
and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)
).first()
if installed_app is None:
# todo: position
recommended_app.install_count += 1
new_installed_app = InstalledApp(
app_id=args['app_id'],
app_id=args["app_id"],
tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id,
is_pinned=False,
last_used_at=datetime.now(timezone.utc).replace(tzinfo=None)
last_used_at=datetime.now(timezone.utc).replace(tzinfo=None),
)
db.session.add(new_installed_app)
db.session.commit()
return {'message': 'App installed successfully'}
return {"message": "App installed successfully"}
class InstalledAppApi(InstalledAppResource):
@@ -94,30 +94,31 @@ class InstalledAppApi(InstalledAppResource):
update and delete an installed app
use InstalledAppResource to apply default decorators and get installed_app
"""
def delete(self, installed_app):
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
raise BadRequest('You can\'t uninstall an app owned by the current tenant')
raise BadRequest("You can't uninstall an app owned by the current tenant")
db.session.delete(installed_app)
db.session.commit()
return {'result': 'success', 'message': 'App uninstalled successfully'}
return {"result": "success", "message": "App uninstalled successfully"}
def patch(self, installed_app):
parser = reqparse.RequestParser()
parser.add_argument('is_pinned', type=inputs.boolean)
parser.add_argument("is_pinned", type=inputs.boolean)
args = parser.parse_args()
commit_args = False
if 'is_pinned' in args:
installed_app.is_pinned = args['is_pinned']
if "is_pinned" in args:
installed_app.is_pinned = args["is_pinned"]
commit_args = True
if commit_args:
db.session.commit()
return {'result': 'success', 'message': 'App info updated successfully'}
return {"result": "success", "message": "App info updated successfully"}
api.add_resource(InstalledAppsListApi, '/installed-apps')
api.add_resource(InstalledAppApi, '/installed-apps/<uuid:installed_app_id>')
api.add_resource(InstalledAppsListApi, "/installed-apps")
api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")

View File

@@ -44,19 +44,21 @@ class MessageListApi(InstalledAppResource):
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("first_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
try:
return MessageService.pagination_by_first_id(app_model, current_user,
args['conversation_id'], args['first_id'], args['limit'])
return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.message.FirstMessageNotExistsError:
raise NotFound("First Message Not Exists.")
class MessageFeedbackApi(InstalledAppResource):
def post(self, installed_app, message_id):
app_model = installed_app.app
@@ -64,30 +66,32 @@ class MessageFeedbackApi(InstalledAppResource):
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
args = parser.parse_args()
try:
MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
MessageService.create_feedback(app_model, message_id, current_user, args["rating"])
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {'result': 'success'}
return {"result": "success"}
class MessageMoreLikeThisApi(InstalledAppResource):
def get(self, installed_app, message_id):
app_model = installed_app.app
if app_model.mode != 'completion':
if app_model.mode != "completion":
raise NotCompletionAppError()
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
parser.add_argument(
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
)
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
streaming = args["response_mode"] == "streaming"
try:
response = AppGenerateService.generate_more_like_this(
@@ -95,7 +99,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
user=current_user,
message_id=message_id,
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
streaming=streaming,
)
return helper.compact_generate_response(response)
except MessageNotExistsError:
@@ -128,10 +132,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=current_user,
message_id=message_id,
invoke_from=InvokeFrom.EXPLORE
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)
except MessageNotExistsError:
raise NotFound("Message not found")
@@ -151,10 +152,22 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
logging.exception("internal server error.")
raise InternalServerError()
return {'data': questions}
return {"data": questions}
api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages')
api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback')
api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this')
api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question')
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
api.add_resource(
MessageFeedbackApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
endpoint="installed_app_message_feedback",
)
api.add_resource(
MessageMoreLikeThisApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
endpoint="installed_app_more_like_this",
)
api.add_resource(
MessageSuggestedQuestionApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="installed_app_suggested_question",
)

View File

@@ -1,4 +1,3 @@
from flask_restful import fields, marshal_with
from configs import dify_config
@@ -11,33 +10,32 @@ from services.app_service import AppService
class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""
variable_fields = {
'key': fields.String,
'name': fields.String,
'description': fields.String,
'type': fields.String,
'default': fields.String,
'max_length': fields.Integer,
'options': fields.List(fields.String)
"key": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"default": fields.String,
"max_length": fields.Integer,
"options": fields.List(fields.String),
}
system_parameters_fields = {
'image_file_size_limit': fields.String
}
system_parameters_fields = {"image_file_size_limit": fields.String}
parameters_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'text_to_speech': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
'file_upload': fields.Raw,
'system_parameters': fields.Nested(system_parameters_fields)
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"suggested_questions_after_answer": fields.Raw,
"speech_to_text": fields.Raw,
"text_to_speech": fields.Raw,
"retriever_resource": fields.Raw,
"annotation_reply": fields.Raw,
"more_like_this": fields.Raw,
"user_input_form": fields.Raw,
"sensitive_word_avoidance": fields.Raw,
"file_upload": fields.Raw,
"system_parameters": fields.Nested(system_parameters_fields),
}
@marshal_with(parameters_fields)
@@ -56,30 +54,35 @@ class AppParameterApi(InstalledAppResource):
app_model_config = app_model.app_model_config
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get('user_input_form', [])
user_input_form = features_dict.get("user_input_form", [])
return {
'opening_statement': features_dict.get('opening_statement'),
'suggested_questions': features_dict.get('suggested_questions', []),
'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
{"enabled": False}),
'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
'user_input_form': user_input_form,
'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
{"enabled": False, "type": "", "configs": []}),
'file_upload': features_dict.get('file_upload', {"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"]
}}),
'system_parameters': {
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
}
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get(
"suggested_questions_after_answer", {"enabled": False}
),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
}
@@ -90,6 +93,7 @@ class ExploreAppMetaApi(InstalledAppResource):
return AppService().get_app_meta(app_model)
api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters',
endpoint='installed_app_parameters')
api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta')
api.add_resource(
AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
)
api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")

View File

@@ -8,28 +8,28 @@ from libs.login import login_required
from services.recommended_app_service import RecommendedAppService
app_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String
"id": fields.String,
"name": fields.String,
"mode": fields.String,
"icon": fields.String,
"icon_background": fields.String,
}
recommended_app_fields = {
'app': fields.Nested(app_fields, attribute='app'),
'app_id': fields.String,
'description': fields.String(attribute='description'),
'copyright': fields.String,
'privacy_policy': fields.String,
'custom_disclaimer': fields.String,
'category': fields.String,
'position': fields.Integer,
'is_listed': fields.Boolean
"app": fields.Nested(app_fields, attribute="app"),
"app_id": fields.String,
"description": fields.String(attribute="description"),
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
}
recommended_app_list_fields = {
'recommended_apps': fields.List(fields.Nested(recommended_app_fields)),
'categories': fields.List(fields.String)
"recommended_apps": fields.List(fields.Nested(recommended_app_fields)),
"categories": fields.List(fields.String),
}
@@ -40,11 +40,11 @@ class RecommendedAppListApi(Resource):
def get(self):
# language args
parser = reqparse.RequestParser()
parser.add_argument('language', type=str, location='args')
parser.add_argument("language", type=str, location="args")
args = parser.parse_args()
if args.get('language') and args.get('language') in languages:
language_prefix = args.get('language')
if args.get("language") and args.get("language") in languages:
language_prefix = args.get("language")
elif current_user and current_user.interface_language:
language_prefix = current_user.interface_language
else:
@@ -61,5 +61,5 @@ class RecommendedAppApi(Resource):
return RecommendedAppService.get_recommend_app_detail(app_id)
api.add_resource(RecommendedAppListApi, '/explore/apps')
api.add_resource(RecommendedAppApi, '/explore/apps/<uuid:app_id>')
api.add_resource(RecommendedAppListApi, "/explore/apps")
api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")

View File

@@ -11,56 +11,54 @@ from libs.helper import TimestampField, uuid_value
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
feedback_fields = {
'rating': fields.String
}
feedback_fields = {"rating": fields.String}
message_fields = {
'id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'created_at': TimestampField
"id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField,
}
class SavedMessageListApi(InstalledAppResource):
saved_message_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_fields))
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
@marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, installed_app):
app_model = installed_app.app
if app_model.mode != 'completion':
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit'])
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
def post(self, installed_app):
app_model = installed_app.app
if app_model.mode != 'completion':
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=uuid_value, required=True, location='json')
parser.add_argument("message_id", type=uuid_value, required=True, location="json")
args = parser.parse_args()
try:
SavedMessageService.save(app_model, current_user, args['message_id'])
SavedMessageService.save(app_model, current_user, args["message_id"])
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {'result': 'success'}
return {"result": "success"}
class SavedMessageApi(InstalledAppResource):
@@ -69,13 +67,21 @@ class SavedMessageApi(InstalledAppResource):
message_id = str(message_id)
if app_model.mode != 'completion':
if app_model.mode != "completion":
raise NotCompletionAppError()
SavedMessageService.delete(app_model, current_user, message_id)
return {'result': 'success'}
return {"result": "success"}
api.add_resource(SavedMessageListApi, '/installed-apps/<uuid:installed_app_id>/saved-messages', endpoint='installed_app_saved_messages')
api.add_resource(SavedMessageApi, '/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>', endpoint='installed_app_saved_message')
api.add_resource(
SavedMessageListApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages",
endpoint="installed_app_saved_messages",
)
api.add_resource(
SavedMessageApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
endpoint="installed_app_saved_message",
)

View File

@@ -35,17 +35,13 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.EXPLORE,
streaming=True
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
return helper.compact_generate_response(response)
@@ -76,10 +72,10 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {
"result": "success"
}
return {"result": "success"}
api.add_resource(InstalledAppWorkflowRunApi, '/installed-apps/<uuid:installed_app_id>/workflows/run')
api.add_resource(InstalledAppWorkflowTaskStopApi, '/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop')
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
api.add_resource(
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
)

View File

@@ -14,29 +14,33 @@ def installed_app_required(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
if not kwargs.get('installed_app_id'):
raise ValueError('missing installed_app_id in path parameters')
if not kwargs.get("installed_app_id"):
raise ValueError("missing installed_app_id in path parameters")
installed_app_id = kwargs.get('installed_app_id')
installed_app_id = kwargs.get("installed_app_id")
installed_app_id = str(installed_app_id)
del kwargs['installed_app_id']
del kwargs["installed_app_id"]
installed_app = db.session.query(InstalledApp).filter(
InstalledApp.id == str(installed_app_id),
InstalledApp.tenant_id == current_user.current_tenant_id
).first()
installed_app = (
db.session.query(InstalledApp)
.filter(
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
)
.first()
)
if installed_app is None:
raise NotFound('Installed app not found')
raise NotFound("Installed app not found")
if not installed_app.app:
db.session.delete(installed_app)
db.session.commit()
raise NotFound('Installed app not found')
raise NotFound("Installed app not found")
return view(installed_app, *args, **kwargs)
return decorated
if view:

View File

@@ -13,23 +13,18 @@ from services.code_based_extension_service import CodeBasedExtensionService
class CodeBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('module', type=str, required=True, location='args')
parser.add_argument("module", type=str, required=True, location="args")
args = parser.parse_args()
return {
'module': args['module'],
'data': CodeBasedExtensionService.get_code_based_extension(args['module'])
}
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
class APIBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -44,23 +39,22 @@ class APIBasedExtensionAPI(Resource):
@marshal_with(api_based_extension_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('api_endpoint', type=str, required=True, location='json')
parser.add_argument('api_key', type=str, required=True, location='json')
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("api_endpoint", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json")
args = parser.parse_args()
extension_data = APIBasedExtension(
tenant_id=current_user.current_tenant_id,
name=args['name'],
api_endpoint=args['api_endpoint'],
api_key=args['api_key']
name=args["name"],
api_endpoint=args["api_endpoint"],
api_key=args["api_key"],
)
return APIBasedExtensionService.save(extension_data)
class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -82,16 +76,16 @@ class APIBasedExtensionDetailAPI(Resource):
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('api_endpoint', type=str, required=True, location='json')
parser.add_argument('api_key', type=str, required=True, location='json')
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("api_endpoint", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json")
args = parser.parse_args()
extension_data_from_db.name = args['name']
extension_data_from_db.api_endpoint = args['api_endpoint']
extension_data_from_db.name = args["name"]
extension_data_from_db.api_endpoint = args["api_endpoint"]
if args['api_key'] != HIDDEN_VALUE:
extension_data_from_db.api_key = args['api_key']
if args["api_key"] != HIDDEN_VALUE:
extension_data_from_db.api_key = args["api_key"]
return APIBasedExtensionService.save(extension_data_from_db)
@@ -106,10 +100,10 @@ class APIBasedExtensionDetailAPI(Resource):
APIBasedExtensionService.delete(extension_data_from_db)
return {'result': 'success'}
return {"result": "success"}
api.add_resource(CodeBasedExtensionAPI, '/code-based-extension')
api.add_resource(CodeBasedExtensionAPI, "/code-based-extension")
api.add_resource(APIBasedExtensionAPI, '/api-based-extension')
api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/<uuid:id>')
api.add_resource(APIBasedExtensionAPI, "/api-based-extension")
api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/<uuid:id>")

View File

@@ -10,7 +10,6 @@ from .wraps import account_initialization_required, cloud_utm_record
class FeatureApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -24,5 +23,5 @@ class SystemFeatureApi(Resource):
return FeatureService.get_system_features().model_dump()
api.add_resource(FeatureApi, '/features')
api.add_resource(SystemFeatureApi, '/system-features')
api.add_resource(FeatureApi, "/features")
api.add_resource(SystemFeatureApi, "/system-features")

View File

@@ -14,12 +14,11 @@ from .wraps import only_edition_self_hosted
class InitValidateAPI(Resource):
def get(self):
init_status = get_init_validate_status()
if init_status:
return { 'status': 'finished' }
return {'status': 'not_started' }
return {"status": "finished"}
return {"status": "not_started"}
@only_edition_self_hosted
def post(self):
@@ -29,22 +28,23 @@ class InitValidateAPI(Resource):
raise AlreadySetupError()
parser = reqparse.RequestParser()
parser.add_argument('password', type=str_len(30),
required=True, location='json')
input_password = parser.parse_args()['password']
parser.add_argument("password", type=str_len(30), required=True, location="json")
input_password = parser.parse_args()["password"]
if input_password != os.environ.get('INIT_PASSWORD'):
session['is_init_validated'] = False
if input_password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False
raise InitValidateFailedError()
session['is_init_validated'] = True
return {'result': 'success'}, 201
session["is_init_validated"] = True
return {"result": "success"}, 201
def get_init_validate_status():
if dify_config.EDITION == 'SELF_HOSTED':
if os.environ.get('INIT_PASSWORD'):
return session.get('is_init_validated') or DifySetup.query.first()
if dify_config.EDITION == "SELF_HOSTED":
if os.environ.get("INIT_PASSWORD"):
return session.get("is_init_validated") or DifySetup.query.first()
return True
api.add_resource(InitValidateAPI, '/init')
api.add_resource(InitValidateAPI, "/init")

View File

@@ -4,14 +4,11 @@ from controllers.console import api
class PingApi(Resource):
def get(self):
"""
For connection health check
"""
return {
"result": "pong"
}
return {"result": "pong"}
api.add_resource(PingApi, '/ping')
api.add_resource(PingApi, "/ping")

View File

@@ -16,17 +16,13 @@ from .wraps import only_edition_self_hosted
class SetupApi(Resource):
def get(self):
if dify_config.EDITION == 'SELF_HOSTED':
if dify_config.EDITION == "SELF_HOSTED":
setup_status = get_setup_status()
if setup_status:
return {
'step': 'finished',
'setup_at': setup_status.setup_at.isoformat()
}
return {'step': 'not_started'}
return {'step': 'finished'}
return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
return {"step": "not_started"}
return {"step": "finished"}
@only_edition_self_hosted
def post(self):
@@ -38,28 +34,22 @@ class SetupApi(Resource):
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
if not get_init_validate_status():
raise NotInitValidateError()
parser = reqparse.RequestParser()
parser.add_argument('email', type=email,
required=True, location='json')
parser.add_argument('name', type=str_len(
30), required=True, location='json')
parser.add_argument('password', type=valid_password,
required=True, location='json')
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("name", type=str_len(30), required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
args = parser.parse_args()
# setup
RegisterService.setup(
email=args['email'],
name=args['name'],
password=args['password'],
ip_address=get_remote_ip(request)
email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
)
return {'result': 'success'}, 201
return {"result": "success"}, 201
def setup_required(view):
@@ -68,7 +58,7 @@ def setup_required(view):
# check setup
if not get_init_validate_status():
raise NotInitValidateError()
elif not get_setup_status():
raise NotSetupError()
@@ -78,9 +68,10 @@ def setup_required(view):
def get_setup_status():
if dify_config.EDITION == 'SELF_HOSTED':
if dify_config.EDITION == "SELF_HOSTED":
return DifySetup.query.first()
else:
return True
api.add_resource(SetupApi, '/setup')
api.add_resource(SetupApi, "/setup")

View File

@@ -14,19 +14,18 @@ from services.tag_service import TagService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError('Name must be between 1 to 50 characters.')
raise ValueError("Name must be between 1 to 50 characters.")
return name
class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(tag_fields)
def get(self):
tag_type = request.args.get('type', type=str)
keyword = request.args.get('keyword', default=None, type=str)
tag_type = request.args.get("type", type=str)
keyword = request.args.get("keyword", default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
return tags, 200
@@ -40,28 +39,21 @@ class TagListApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False, required=True,
help='Name must be between 1 to 50 characters.',
type=_validate_name)
parser.add_argument('type', type=str, location='json',
choices=Tag.TAG_TYPE_LIST,
nullable=True,
help='Invalid tag type.')
parser.add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
parser.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
args = parser.parse_args()
tag = TagService.save_tags(args)
response = {
'id': tag.id,
'name': tag.name,
'type': tag.type,
'binding_count': 0
}
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200
class TagUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -72,20 +64,15 @@ class TagUpdateDeleteApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False, required=True,
help='Name must be between 1 to 50 characters.',
type=_validate_name)
parser.add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
args = parser.parse_args()
tag = TagService.update_tags(args, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
response = {
'id': tag.id,
'name': tag.name,
'type': tag.type,
'binding_count': binding_count
}
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
return response, 200
@@ -104,7 +91,6 @@ class TagUpdateDeleteApi(Resource):
class TagBindingCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -114,14 +100,15 @@ class TagBindingCreateApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json',
help='Tag IDs is required.')
parser.add_argument('target_id', type=str, nullable=False, required=True, location='json',
help='Target ID is required.')
parser.add_argument('type', type=str, location='json',
choices=Tag.TAG_TYPE_LIST,
nullable=True,
help='Invalid tag type.')
parser.add_argument(
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
)
parser.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
)
parser.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
args = parser.parse_args()
TagService.save_tag_binding(args)
@@ -129,7 +116,6 @@ class TagBindingCreateApi(Resource):
class TagBindingDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -139,21 +125,18 @@ class TagBindingDeleteApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('tag_id', type=str, nullable=False, required=True,
help='Tag ID is required.')
parser.add_argument('target_id', type=str, nullable=False, required=True,
help='Target ID is required.')
parser.add_argument('type', type=str, location='json',
choices=Tag.TAG_TYPE_LIST,
nullable=True,
help='Invalid tag type.')
parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
parser.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
args = parser.parse_args()
TagService.delete_tag_binding(args)
return 200
api.add_resource(TagListApi, '/tags')
api.add_resource(TagUpdateDeleteApi, '/tags/<uuid:tag_id>')
api.add_resource(TagBindingCreateApi, '/tag-bindings/create')
api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove')
api.add_resource(TagListApi, "/tags")
api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>")
api.add_resource(TagBindingCreateApi, "/tag-bindings/create")
api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove")

View File

@@ -1,4 +1,3 @@
import json
import logging
@@ -11,42 +10,39 @@ from . import api
class VersionApi(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('current_version', type=str, required=True, location='args')
parser.add_argument("current_version", type=str, required=True, location="args")
args = parser.parse_args()
check_update_url = dify_config.CHECK_UPDATE_URL
result = {
'version': dify_config.CURRENT_VERSION,
'release_date': '',
'release_notes': '',
'can_auto_update': False,
'features': {
'can_replace_logo': dify_config.CAN_REPLACE_LOGO,
'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED
}
"version": dify_config.CURRENT_VERSION,
"release_date": "",
"release_notes": "",
"can_auto_update": False,
"features": {
"can_replace_logo": dify_config.CAN_REPLACE_LOGO,
"model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED,
},
}
if not check_update_url:
return result
try:
response = requests.get(check_update_url, {
'current_version': args.get('current_version')
})
response = requests.get(check_update_url, {"current_version": args.get("current_version")})
except Exception as error:
logging.warning("Check update version error: {}.".format(str(error)))
result['version'] = args.get('current_version')
result["version"] = args.get("current_version")
return result
content = json.loads(response.content)
result['version'] = content['version']
result['release_date'] = content['releaseDate']
result['release_notes'] = content['releaseNotes']
result['can_auto_update'] = content['canAutoUpdate']
result["version"] = content["version"]
result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"]
result["can_auto_update"] = content["canAutoUpdate"]
return result
api.add_resource(VersionApi, '/version')
api.add_resource(VersionApi, "/version")

Some files were not shown because too many files have changed in this diff Show More