Compare commits

..

138 Commits

Author SHA1 Message Date
-LAN-
4b0d2bf57f chore: update build-push.yml to remove unnecessary tags 2024-09-04 20:20:39 +08:00
-LAN-
94432a0a69 chore: update package versions to 0.8.0-beta1 (#7979) 2024-09-04 19:56:33 +08:00
takatost
7e30487f8b feat: update dsl version 2024-09-04 19:41:43 +08:00
Yi
46634638e7 fix: refine the "isInIteration" for workflow 2024-09-04 17:45:07 +08:00
StyleZhang
44038b9628 fix: iteration copy 2024-09-04 17:26:43 +08:00
StyleZhang
c625f4282f Merge branch 'main' into feat/workflow-parallel-support 2024-09-04 15:22:57 +08:00
StyleZhang
4f5dc82459 fix 2024-09-04 15:03:35 +08:00
Yi
5cb018e15d update the method to check if a node is in iteration 2024-09-04 14:59:30 +08:00
StyleZhang
4962b2c460 check node edge 2024-09-04 13:27:17 +08:00
Yi
cd42dbdae8 update the log for iteration nodes 2024-09-04 10:27:40 +08:00
takatost
78fa1f6868 fix(workflow): detached session issues 2024-09-03 18:23:37 +08:00
Yi
6bee121ebe update log in web app 2024-09-03 17:27:32 +08:00
takatost
36d95e49b0 fix(iteration): iterator_length not correct 2024-09-03 12:01:56 +08:00
Yi
3431b19f9a update styling and iteration log 2024-09-03 11:46:04 +08:00
Yi
b28c7b1cda Merge branch 'feat/workflow-parallel-support' of github.com:langgenius/dify into feat/workflow-parallel-support 2024-09-03 10:35:15 +08:00
Yi
83343eefe6 update parallel log 2024-09-03 10:34:50 +08:00
takatost
d92966545b fix: migration 2024-09-02 22:41:08 +08:00
takatost
f71c51cb9a Merge branch 'refs/heads/main' into feat/workflow-parallel-support 2024-09-02 22:37:23 +08:00
takatost
955884b87e chore(workflow): max thread submit count 2024-09-02 20:20:32 +08:00
takatost
5ca9df65de feat(workflow): add thread pool 2024-09-02 19:02:45 +08:00
takatost
166365a502 feat(workflow): add thread pool 2024-09-02 19:02:21 +08:00
StyleZhang
70aced0100 fix 2024-09-02 18:38:21 +08:00
takatost
35d9c59a29 Merge remote-tracking branch 'origin/feat/workflow-parallel-support' into feat/workflow-parallel-support 2024-09-02 17:56:16 +08:00
takatost
bbc922dffa merge main 2024-09-02 17:55:28 +08:00
StyleZhang
7035f64ce3 fix: next step 2024-09-02 17:52:54 +08:00
takatost
81d09d471c Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/workflow/app_generator.py
2024-09-02 17:52:51 +08:00
takatost
5bda3a384a fix(workflow): bugs 2024-09-02 17:49:51 +08:00
StyleZhang
43240fcd41 fix 2024-09-02 14:50:05 +08:00
takatost
52b4623131 fix(workflow): fix merge branch node id err 2024-09-02 13:56:07 +08:00
takatost
0dabf799c0 fix(workflow): fix merge branch node id err 2024-09-02 11:52:14 +08:00
Yi
29b1ce781d fix: node end status 2024-09-01 22:00:54 +08:00
Yi
71a7d890cc fix styling 2024-08-30 23:31:05 +08:00
Yi
ee1587c939 fix: make the End node always nested in the root 2024-08-30 20:14:56 +08:00
Yi
d7c0ca852e feat: inner parallels will be added to its corresponding branch 2024-08-30 20:08:57 +08:00
takatost
162e9677c7 fix(workflow): missing parallel event in workflow app 2024-08-30 20:04:17 +08:00
takatost
77e62f7fee fix(workflow): run node in multi parallel bugs 2024-08-30 18:55:33 +08:00
Yi
e3295181d2 fix a typo 2024-08-30 18:01:13 +08:00
Yi
2b5b856126 solve the branch issue 2024-08-30 17:58:29 +08:00
Yi
e3ae529a55 update the onNodeFinished method for nodes being passed through more than once 2024-08-30 17:00:02 +08:00
Yi
708256ef1d Merge branch 'feat/workflow-parallel-support' of github.com:langgenius/dify into feat/workflow-parallel-support 2024-08-30 15:23:21 +08:00
StyleZhang
7c9081a8fc fix 2024-08-30 13:44:01 +08:00
Yi
1bde57e591 delete console logs 2024-08-29 17:54:26 +08:00
Yi
32a11cbb6a update the parallel workflow log for iteration and chatflow preview 2024-08-29 17:26:17 +08:00
Yi
3e257ae907 update the workflow parallel log 2024-08-29 16:38:51 +08:00
StyleZhang
f43596f226 fix: parallel branch limit 2024-08-29 11:31:34 +08:00
takatost
ae22015fe7 fix(workflow): loop check 2024-08-28 21:47:47 +08:00
takatost
790dd3b22f fix(workflow): duplicate nodes in parallel 2024-08-28 19:01:45 +08:00
takatost
5d34e080eb fix: migration 2024-08-28 18:02:49 +08:00
takatost
6b6750b9ad Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/services/app_generate_service.py
2024-08-28 18:01:57 +08:00
takatost
74c8004944 fix(graph_engine): fix execute loops in parallel 2024-08-28 17:42:42 +08:00
StyleZhang
4418fa1d2b fix: bug 2024-08-28 17:40:50 +08:00
takatost
c2bb11405f fix(workflow): parallel not yield 2024-08-28 16:13:57 +08:00
StyleZhang
8ba5673606 feat: iteration support parallel 2024-08-28 16:00:17 +08:00
takatost
b0a81c654b fix(workflow): parallel execution after if-else that only one branch runs 2024-08-28 15:53:39 +08:00
takatost
cd52633b0e fix(graph_engine): parent_parallel_id missing 2024-08-27 16:45:14 +08:00
takatost
4256e9d47f chore(iteration): keep start_node_id using in parallel start nodes 2024-08-27 16:38:33 +08:00
StyleZhang
4e3dc36e37 fix: workflow run edge status 2024-08-27 14:39:56 +08:00
takatost
b9f34f679f fix: iteration start node id 2024-08-26 22:00:17 +08:00
StyleZhang
9c8144e463 feat: parallel hover 2024-08-26 17:49:11 +08:00
takatost
76bb8d1c1a Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/services/app_generate_service.py
#	api/services/workflow_service.py
2024-08-26 16:17:19 +08:00
StyleZhang
1016db160e feat: parallel hover 2024-08-26 16:09:22 +08:00
takatost
6c61776ee1 fix test 2024-08-25 22:02:21 +08:00
takatost
4771e85630 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/tests/integration_tests/workflow/nodes/test_code.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
2024-08-24 17:26:44 +08:00
takatost
85d319719c fix end node bug 2024-08-24 17:17:18 +08:00
takatost
42899fb3be fix bug 2024-08-23 00:38:42 +08:00
takatost
5b22d8f8b2 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/nodes/question_classifier/question_classifier_node.py
2024-08-23 00:32:28 +08:00
takatost
fe2b300288 fix lint 2024-08-22 23:54:07 +08:00
takatost
ec4fc784f0 fix iteration start node 2024-08-22 23:53:44 +08:00
takatost
d6da7b0336 fix dialogue_count 2024-08-22 13:06:17 +08:00
takatost
92072e2ed7 fix: ruff issues 2024-08-21 17:26:51 +08:00
takatost
e34497ded1 fix: merge issues 2024-08-21 17:25:26 +08:00
takatost
35be41b337 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/task_pipeline/workflow_cycle_state_manager.py
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/code/code_node.py
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/nodes/start/start_node.py
#	api/core/workflow/nodes/variable_assigner/__init__.py
#	api/tests/integration_tests/workflow/nodes/test_llm.py
#	api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
2024-08-21 16:59:23 +08:00
takatost
412be6d014 fix bug 2024-08-21 16:43:00 +08:00
takatost
1d88b62e25 fix(workflow): fix node link to previous node issue 2024-08-20 23:28:11 +08:00
takatost
617ea4b3b8 fix(workflow): fix parallel bug 2024-08-20 22:16:41 +08:00
takatost
755a9658c7 fix(workflow): add parallel id into published events 2024-08-18 20:18:13 +08:00
takatost
5d7865737f fix(workflow): issues in workflow parallels 2024-08-16 22:47:58 +08:00
takatost
352c45c8a2 feat(workflow): integrate parallel into workflow apps 2024-08-16 21:33:09 +08:00
StyleZhang
1973f5003b feat: frontend support parallel 2024-08-16 16:55:08 +08:00
takatost
5b5e6e31bf fix: answer node unit tests 2024-08-16 01:44:00 +08:00
takatost
90221c0a90 fix: unit tests 2024-08-16 01:43:35 +08:00
takatost
91e51ce2b8 fix(workflow): issues by merging main branch 2024-08-16 01:36:19 +08:00
takatost
db9b0ee985 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/base_app_runner.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/task_pipeline/workflow_cycle_state_manager.py
#	api/core/workflow/entities/node_entities.py
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/workflow_engine_manager.py
#	api/tests/integration_tests/workflow/nodes/test_llm.py
#	api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
2024-08-16 01:19:29 +08:00
takatost
c5192650fb fix: unit tests in workflow 2024-08-15 23:47:59 +08:00
takatost
702df31db7 fix(workflow): fix generate issues in workflow 2024-08-15 20:45:23 +08:00
takatost
1da5862a96 feat(workflow): fix iteration single debug 2024-08-15 03:12:49 +08:00
takatost
6f6b32e1ee feat(workflow): integrate workflow entry with workflow app 2024-08-14 19:22:15 +08:00
takatost
674af04c39 fix migration version depends 2024-08-13 17:15:21 +08:00
takatost
2980e31ddf fix issues when merging from main 2024-08-13 17:11:19 +08:00
takatost
14d020fffe Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/task_pipeline/workflow_cycle_manage.py
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/base_node.py
#	api/core/workflow/workflow_engine_manager.py
2024-08-13 17:05:39 +08:00
takatost
8401a11109 feat(workflow): integrate workflow entry with advanced chat app 2024-08-13 16:21:10 +08:00
takatost
8d27ec364f fix bug 2024-07-31 02:27:23 +08:00
takatost
c9bb366e1a Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/iteration/iteration_node.py
#	api/core/workflow/workflow_engine_manager.py
2024-07-31 02:25:31 +08:00
takatost
917aacbf7f add chatflow app event convert 2024-07-31 02:21:35 +08:00
takatost
0818b7b078 remove iteration special logic 2024-07-26 21:27:01 +08:00
takatost
88dcd7b737 fix bug 2024-07-26 20:29:12 +08:00
takatost
63addf8c94 add parallel branch events 2024-07-26 20:27:17 +08:00
takatost
483f71f03c fix logging 2024-07-26 20:13:11 +08:00
takatost
beea1e1663 fix lint 2024-07-26 19:47:12 +08:00
takatost
38f8c45755 add events in interation node 2024-07-26 19:47:02 +08:00
takatost
a31feacf28 fix iteration 2024-07-26 02:43:40 +08:00
takatost
ae351bd40e add iteration support 2024-07-25 23:07:27 +08:00
takatost
df133168dd fix lint 2024-07-25 21:06:23 +08:00
takatost
7c67ba8991 remove threadpool 2024-07-25 21:05:53 +08:00
takatost
4097f7c069 add parallel branch output 2024-07-25 19:39:06 +08:00
takatost
f4eb7cd037 add end stream output test 2024-07-25 04:03:53 +08:00
takatost
833584ba76 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/workflow_entry.py
2024-07-24 23:43:14 +08:00
takatost
ec7760795f save 2024-07-24 00:24:24 +08:00
takatost
e9bfedab9b Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/entities/variable_pool.py
2024-07-23 17:28:57 +08:00
takatost
7303b53af1 fix bug 2024-07-23 16:18:52 +08:00
takatost
0fe516568a Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/nodes/code/code_node.py
#	api/core/workflow/nodes/end/end_node.py
#	api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
#	api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
#	api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
2024-07-23 16:18:34 +08:00
takatost
2c695ded79 fix bugs 2024-07-23 00:10:23 +08:00
takatost
a603e01f5e fix bug 2024-07-22 19:57:32 +08:00
takatost
beaac5033a fix bug 2024-07-20 00:57:41 +08:00
takatost
dad1a967ee finished answer stream output 2024-07-20 00:49:46 +08:00
takatost
7ad77e9e77 fix test 2024-07-18 08:19:58 +08:00
takatost
f67a88f44d fix test 2024-07-17 21:17:04 +08:00
takatost
90e518b05b fix bugs 2024-07-17 16:54:49 +08:00
takatost
cc96acdae3 fix bugs 2024-07-17 11:26:33 +08:00
takatost
16e2d00157 optimize 2024-07-17 01:07:23 +08:00
takatost
4ef3d4e65c optimize 2024-07-17 01:02:40 +08:00
takatost
775e52db4d merge 2024-07-16 17:46:20 +08:00
takatost
00ec36d47c add graph engine test 2024-07-16 16:37:37 +08:00
takatost
00fb23d0c9 graph engine implement 2024-07-15 23:40:02 +08:00
takatost
821e09b259 add run logics 2024-07-12 19:33:47 +08:00
takatost
d77b689a99 completed parallel tests 2024-07-10 21:21:06 +08:00
takatost
0e885a3cae refactor runtime 2024-07-08 16:29:13 +08:00
takatost
1adaf42f9d refactor graph 2024-07-07 23:08:45 +08:00
takatost
fed068ac2e Merge branch 'refs/heads/main' into feat/workflow-parallel-support 2024-07-07 16:57:21 +08:00
takatost
03f56a05eb refactor graph 2024-07-06 03:18:02 +08:00
takatost
1b6cd975f3 completed graph init test 2024-07-04 15:40:20 +08:00
takatost
0f19b2a986 optimize graph 2024-07-02 21:53:41 +08:00
takatost
8375517ccd save 2024-06-29 15:44:52 +08:00
takatost
1d8ecac093 save 2024-06-27 05:30:38 +08:00
takatost
aaa98c76d5 optimize 2024-06-26 23:56:30 +08:00
takatost
216910a4a1 add runtime state of graph 2024-06-25 17:43:13 +08:00
takatost
fe27c97fd9 add runtime graph 2024-06-25 14:41:14 +08:00
takatost
8217c46116 add new graph structure 2024-06-24 23:34:42 +08:00
1733 changed files with 29863 additions and 53350 deletions

View File

@@ -125,7 +125,6 @@ jobs:
with:
images: ${{ env[matrix.image_name_env] }}
tags: |
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}

View File

@@ -20,7 +20,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
uses: tj-actions/changed-files@v44
with:
files: api/**
@@ -66,7 +66,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
uses: tj-actions/changed-files@v44
with:
files: web/**
@@ -97,7 +97,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
uses: tj-actions/changed-files@v44
with:
files: |
**.sh
@@ -107,7 +107,7 @@ jobs:
dev/**
- name: Super-linter
uses: super-linter/super-linter/slim@v7
uses: super-linter/super-linter/slim@v6
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning

9
.gitignore vendored
View File

@@ -153,9 +153,6 @@ docker-legacy/volumes/etcd/*
docker-legacy/volumes/minio/*
docker-legacy/volumes/milvus/*
docker-legacy/volumes/chroma/*
docker-legacy/volumes/opensearch/data/*
docker-legacy/volumes/pgvectors/data/*
docker-legacy/volumes/pgvector/data/*
docker/volumes/app/storage/*
docker/volumes/certbot/*
@@ -167,12 +164,6 @@ docker/volumes/etcd/*
docker/volumes/minio/*
docker/volumes/milvus/*
docker/volumes/chroma/*
docker/volumes/opensearch/data/*
docker/volumes/myscale/data/*
docker/volumes/myscale/log/*
docker/volumes/unstructured/*
docker/volumes/pgvector/data/*
docker/volumes/pgvecto_rs/data/*
docker/nginx/conf.d/default.conf
docker/middleware.env

View File

@@ -36,7 +36,7 @@
| 被团队成员标记为高优先级的功能 | 高优先级 |
| 在 [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) 内反馈的常见功能请求 | 中等优先级 |
| 非核心功能和小幅改进 | 低优先级 |
| 有价值不紧急 | 未来功能 |
| 有价值不紧急 | 未来功能 |
### 其他任何事情(例如 bug 报告、性能优化、拼写错误更正):
* 立即开始编码。
@@ -138,7 +138,7 @@ Dify 的后端使用 Python 编写,使用 [Flask](https://flask.palletsproject
├── models // 描述数据模型和 API 响应的形状
├── public // 如 favicon 等元资源
├── service // 定义 API 操作的形状
├── test
├── test
├── types // 函数参数和返回值的描述
└── utils // 共享的实用函数
```

View File

@@ -4,7 +4,7 @@ Dify is licensed under the Apache License 2.0, with the following additional con
1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.

View File

@@ -39,7 +39,7 @@ DB_DATABASE=dify
# Storage configuration
# use for store upload files, private keys...
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos
# storage type: local, s3, azure-blob, google-storage
STORAGE_TYPE=local
STORAGE_LOCAL_PATH=storage
S3_USE_AWS_MANAGED_IAM=false
@@ -73,12 +73,6 @@ TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme
# Huawei OBS Storage Configuration
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
HUAWEI_OBS_SECRET_KEY=your-secret-key
HUAWEI_OBS_ACCESS_KEY=your-access-key
HUAWEI_OBS_SERVER=your-server-url
# OCI Storage configuration
OCI_ENDPOINT=your-endpoint
OCI_BUCKET_NAME=your-bucket-name
@@ -86,13 +80,6 @@ OCI_ACCESS_KEY=your-access-key
OCI_SECRET_KEY=your-secret-key
OCI_REGION=your-region
# Volcengine tos Storage configuration
VOLCENGINE_TOS_ENDPOINT=your-endpoint
VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name
VOLCENGINE_TOS_ACCESS_KEY=your-access-key
VOLCENGINE_TOS_SECRET_KEY=your-secret-key
VOLCENGINE_TOS_REGION=your-region
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
@@ -114,10 +101,11 @@ QDRANT_GRPC_ENABLED=false
QDRANT_GRPC_PORT=6334
# Milvus configuration
MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN=
MILVUS_HOST=127.0.0.1
MILVUS_PORT=19530
MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false
# MyScale configuration
MYSCALE_HOST=127.0.0.1

View File

@@ -55,7 +55,7 @@ RUN apt-get update \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \
# For Security
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-2 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*

View File

@@ -65,12 +65,14 @@
8. Start Dify [web](../web) service.
9. Setup your application by visiting `http://localhost:3000`...
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
10. If you need to debug local async processing, please start the worker service.
```bash
poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
```
The started celery app handles the async tasks, e.g. dataset importing and documents indexing.
## Testing
1. Install dependencies for both the backend and the test environment

View File

@@ -164,7 +164,7 @@ def initialize_extensions(app):
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint not in {"console", "inner_api"}:
if request.blueprint not in ["console", "inner_api"]:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get("Authorization", "")

View File

@@ -28,28 +28,28 @@ from services.account_service import RegisterService, TenantService
@click.command("reset-password", help="Reset the account password.")
@click.option("--email", prompt=True, help="Account email to reset password for")
@click.option("--new-password", prompt=True, help="New password")
@click.option("--password-confirm", prompt=True, help="Confirm new password")
@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset")
@click.option("--new-password", prompt=True, help="the new password.")
@click.option("--password-confirm", prompt=True, help="the new password confirm.")
def reset_password(email, new_password, password_confirm):
"""
Reset password of owner account
Only available in SELF_HOSTED mode
"""
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red"))
click.echo(click.style("sorry. The two passwords do not match.", fg="red"))
return
account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(click.style("Invalid password. Must match {}".format(password_pattern), fg="red"))
click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red"))
return
# generate password salt
@@ -62,37 +62,37 @@ def reset_password(email, new_password, password_confirm):
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
click.echo(click.style("Password reset successfully.", fg="green"))
click.echo(click.style("Congratulations! Password has been reset.", fg="green"))
@click.command("reset-email", help="Reset the account email.")
@click.option("--email", prompt=True, help="Current account email")
@click.option("--new-email", prompt=True, help="New email")
@click.option("--email-confirm", prompt=True, help="Confirm new email")
@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset")
@click.option("--new-email", prompt=True, help="the new email.")
@click.option("--email-confirm", prompt=True, help="the new email confirm.")
def reset_email(email, new_email, email_confirm):
"""
Replace account email
:return:
"""
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red"))
click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red"))
return
account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
return
try:
email_validate(new_email)
except:
click.echo(click.style("Invalid email: {}".format(new_email), fg="red"))
click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red"))
return
account.email = new_email
db.session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))
click.echo(click.style("Congratulations!, email has been reset.", fg="green"))
@click.command(
@@ -104,7 +104,7 @@ def reset_email(email, new_email, email_confirm):
)
@click.confirmation_option(
prompt=click.style(
"Are you sure you want to reset encrypt key pair? This operation cannot be rolled back!", fg="red"
"Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
)
)
def reset_encrypt_key_pair():
@@ -114,13 +114,13 @@ def reset_encrypt_key_pair():
Only support SELF_HOSTED mode.
"""
if dify_config.EDITION != "SELF_HOSTED":
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red"))
return
tenants = db.session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red"))
return
tenant.encrypt_public_key = generate_key_pair(tenant.id)
@@ -131,18 +131,18 @@ def reset_encrypt_key_pair():
click.echo(
click.style(
"Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id),
"Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
fg="green",
)
)
@click.command("vdb-migrate", help="Migrate vector db.")
@click.command("vdb-migrate", help="migrate vector db.")
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
def vdb_migrate(scope: str):
if scope in {"knowledge", "all"}:
if scope in ["knowledge", "all"]:
migrate_knowledge_vector_database()
if scope in {"annotation", "all"}:
if scope in ["annotation", "all"]:
migrate_annotation_vector_database()
@@ -150,7 +150,7 @@ def migrate_annotation_vector_database():
"""
Migrate annotation datas to target vector database .
"""
click.echo(click.style("Starting annotation data migration.", fg="green"))
click.echo(click.style("Start migrate annotation data.", fg="green"))
create_count = 0
skipped_count = 0
total_count = 0
@@ -174,14 +174,14 @@ def migrate_annotation_vector_database():
f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
)
try:
click.echo("Creating app annotation index: {}".format(app.id))
click.echo("Create app annotation index: {}".format(app.id))
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
)
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo("App annotation setting disabled: {}".format(app.id))
click.echo("App annotation setting is disabled: {}".format(app.id))
continue
# get dataset_collection_binding info
dataset_collection_binding = (
@@ -190,7 +190,7 @@ def migrate_annotation_vector_database():
.first()
)
if not dataset_collection_binding:
click.echo("App annotation collection binding not found: {}".format(app.id))
click.echo("App annotation collection binding is not exist: {}".format(app.id))
continue
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
dataset = Dataset(
@@ -211,11 +211,11 @@ def migrate_annotation_vector_database():
documents.append(document)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
click.echo(f"Migrating annotations for app: {app.id}.")
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
try:
vector.delete()
click.echo(click.style(f"Deleted vector index for app {app.id}.", fg="green"))
click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green"))
except Exception as e:
click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
raise e
@@ -223,12 +223,12 @@ def migrate_annotation_vector_database():
try:
click.echo(
click.style(
f"Creating vector index with {len(documents)} annotations for app {app.id}.",
f"Start to created vector index with {len(documents)} annotations for app {app.id}.",
fg="green",
)
)
vector.create(documents)
click.echo(click.style(f"Created vector index for app {app.id}.", fg="green"))
click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green"))
except Exception as e:
click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
raise e
@@ -237,14 +237,14 @@ def migrate_annotation_vector_database():
except Exception as e:
click.echo(
click.style(
"Error creating app annotation index: {} {}".format(e.__class__.__name__, str(e)), fg="red"
"Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red"
)
)
continue
click.echo(
click.style(
f"Migration complete. Created {create_count} app annotation indexes. Skipped {skipped_count} apps.",
f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.",
fg="green",
)
)
@@ -254,7 +254,7 @@ def migrate_knowledge_vector_database():
"""
Migrate vector database datas to target vector database .
"""
click.echo(click.style("Starting vector database migration.", fg="green"))
click.echo(click.style("Start migrate vector db.", fg="green"))
create_count = 0
skipped_count = 0
total_count = 0
@@ -275,10 +275,11 @@ def migrate_knowledge_vector_database():
for dataset in datasets:
total_count = total_count + 1
click.echo(
f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped."
f"Processing the {total_count} dataset {dataset.id}. "
+ f"{create_count} created, {skipped_count} skipped."
)
try:
click.echo("Creating dataset vector database index: {}".format(dataset.id))
click.echo("Create dataset vdb index: {}".format(dataset.id))
if dataset.index_struct_dict:
if dataset.index_struct_dict["type"] == vector_type:
skipped_count = skipped_count + 1
@@ -299,7 +300,7 @@ def migrate_knowledge_vector_database():
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError("Dataset Collection Binding not found")
raise ValueError("Dataset Collection Bindings is not exist!")
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
@@ -351,12 +352,14 @@ def migrate_knowledge_vector_database():
raise ValueError(f"Vector store {vector_type} is not supported.")
vector = Vector(dataset)
click.echo(f"Migrating dataset {dataset.id}.")
click.echo(f"Start to migrate dataset {dataset.id}.")
try:
vector.delete()
click.echo(
click.style(f"Deleted vector index {collection_name} for dataset {dataset.id}.", fg="green")
click.style(
f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green"
)
)
except Exception as e:
click.echo(
@@ -408,13 +411,14 @@ def migrate_knowledge_vector_database():
try:
click.echo(
click.style(
f"Creating vector index with {len(documents)} documents of {segments_count}"
f" segments for dataset {dataset.id}.",
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
fg="green",
)
)
vector.create(documents)
click.echo(click.style(f"Created vector index for dataset {dataset.id}.", fg="green"))
click.echo(
click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green")
)
except Exception as e:
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
raise e
@@ -425,13 +429,13 @@ def migrate_knowledge_vector_database():
except Exception as e:
db.session.rollback()
click.echo(
click.style("Error creating dataset index: {} {}".format(e.__class__.__name__, str(e)), fg="red")
click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
)
continue
click.echo(
click.style(
f"Migration complete. Created {create_count} dataset indexes. Skipped {skipped_count} datasets.", fg="green"
f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green"
)
)
@@ -441,7 +445,7 @@ def convert_to_agent_apps():
"""
Convert Agent Assistant to Agent App.
"""
click.echo(click.style("Starting convert to agent apps.", fg="green"))
click.echo(click.style("Start convert to agent apps.", fg="green"))
proceeded_app_ids = []
@@ -492,23 +496,23 @@ def convert_to_agent_apps():
except Exception as e:
click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
click.echo(click.style("Conversion complete. Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
@click.command("add-qdrant-doc-id-index", help="Add Qdrant doc_id index.")
@click.option("--field", default="metadata.doc_id", prompt=False, help="Index field , default is metadata.doc_id.")
@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.")
@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.")
def add_qdrant_doc_id_index(field: str):
click.echo(click.style("Starting Qdrant doc_id index creation.", fg="green"))
click.echo(click.style("Start add qdrant doc_id index.", fg="green"))
vector_type = dify_config.VECTOR_STORE
if vector_type != "qdrant":
click.echo(click.style("This command only supports Qdrant vector store.", fg="red"))
click.echo(click.style("Sorry, only support qdrant vector store.", fg="red"))
return
create_count = 0
try:
bindings = db.session.query(DatasetCollectionBinding).all()
if not bindings:
click.echo(click.style("No dataset collection bindings found.", fg="red"))
click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red"))
return
import qdrant_client
from qdrant_client.http.exceptions import UnexpectedResponse
@@ -518,7 +522,7 @@ def add_qdrant_doc_id_index(field: str):
for binding in bindings:
if dify_config.QDRANT_URL is None:
raise ValueError("Qdrant URL is required.")
raise ValueError("Qdrant url is required.")
qdrant_config = QdrantConfig(
endpoint=dify_config.QDRANT_URL,
api_key=dify_config.QDRANT_API_KEY,
@@ -535,39 +539,41 @@ def add_qdrant_doc_id_index(field: str):
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
click.echo(click.style(f"Collection not found: {binding.collection_name}.", fg="red"))
click.echo(
click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red")
)
continue
# Some other error occurred, so re-raise the exception
else:
click.echo(
click.style(
f"Failed to create Qdrant index for collection: {binding.collection_name}.", fg="red"
f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red"
)
)
except Exception as e:
click.echo(click.style("Failed to create Qdrant client.", fg="red"))
click.echo(click.style("Failed to create qdrant client.", fg="red"))
click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green"))
click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green"))
@click.command("create-tenant", help="Create account and tenant.")
@click.option("--email", prompt=True, help="Tenant account email.")
@click.option("--name", prompt=True, help="Workspace name.")
@click.option("--email", prompt=True, help="The email address of the tenant account.")
@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
@click.option("--language", prompt=True, help="Account language, default: en-US.")
def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
"""
Create tenant account
"""
if not email:
click.echo(click.style("Email is required.", fg="red"))
click.echo(click.style("Sorry, email is required.", fg="red"))
return
# Create account
email = email.strip()
if "@" not in email:
click.echo(click.style("Invalid email address.", fg="red"))
click.echo(click.style("Sorry, invalid email address.", fg="red"))
return
account_name = email.split("@")[0]
@@ -587,19 +593,19 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
click.echo(
click.style(
"Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
"Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
fg="green",
)
)
@click.command("upgrade-db", help="Upgrade the database")
@click.command("upgrade-db", help="upgrade the database")
def upgrade_db():
click.echo("Preparing database migration...")
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
if lock.acquire(blocking=False):
try:
click.echo(click.style("Starting database migration.", fg="green"))
click.echo(click.style("Start database migration.", fg="green"))
# run db migration
import flask_migrate
@@ -609,7 +615,7 @@ def upgrade_db():
click.echo(click.style("Database migration successful!", fg="green"))
except Exception as e:
logging.exception(f"Database migration failed: {e}")
logging.exception(f"Database migration failed, error: {e}")
finally:
lock.release()
else:
@@ -621,7 +627,7 @@ def fix_app_site_missing():
"""
Fix app related site missing issue.
"""
click.echo(click.style("Starting fix for missing app-related sites.", fg="green"))
click.echo(click.style("Start fix app related site missing issue.", fg="green"))
failed_app_ids = []
while True:
@@ -644,22 +650,22 @@ where sites.id is null limit 1000"""
if tenant:
accounts = tenant.get_accounts()
if not accounts:
print("Fix failed for app {}".format(app.id))
print("Fix app {} failed.".format(app.id))
continue
account = accounts[0]
print("Fixing missing site for app {}".format(app.id))
print("Fix app {} related site missing issue.".format(app.id))
app_was_created.send(app, account=account)
except Exception as e:
failed_app_ids.append(app_id)
click.echo(click.style("FFailed to fix missing site for app {}".format(app_id), fg="red"))
click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red"))
logging.exception(f"Fix app related site missing issue failed, error: {e}")
continue
if not processed_count:
break
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green"))
def register_commands(app):

View File

@@ -4,30 +4,30 @@ from pydantic_settings import BaseSettings
class DeploymentConfig(BaseSettings):
"""
Configuration settings for application deployment
Deployment configs
"""
APPLICATION_NAME: str = Field(
description="Name of the application, used for identification and logging purposes",
description="application name",
default="langgenius/dify",
)
DEBUG: bool = Field(
description="Enable debug mode for additional logging and development features",
description="whether to enable debug mode.",
default=False,
)
TESTING: bool = Field(
description="Enable testing mode for running automated tests",
description="",
default=False,
)
EDITION: str = Field(
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
description="deployment edition",
default="SELF_HOSTED",
)
DEPLOY_ENV: str = Field(
description="Deployment environment (e.g., 'PRODUCTION', 'DEVELOPMENT'), default to PRODUCTION",
description="deployment environment, default to PRODUCTION.",
default="PRODUCTION",
)

View File

@@ -4,17 +4,17 @@ from pydantic_settings import BaseSettings
class EnterpriseFeatureConfig(BaseSettings):
"""
Configuration for enterprise-level features.
Enterprise feature configs.
**Before using, please contact business@dify.ai by email to inquire about licensing matters.**
"""
ENTERPRISE_ENABLED: bool = Field(
description="Enable or disable enterprise-level features."
description="whether to enable enterprise features."
"Before using, please contact business@dify.ai by email to inquire about licensing matters.",
default=False,
)
CAN_REPLACE_LOGO: bool = Field(
description="Allow customization of the enterprise logo.",
description="whether to allow replacing enterprise logo.",
default=False,
)

View File

@@ -6,31 +6,30 @@ from pydantic_settings import BaseSettings
class NotionConfig(BaseSettings):
"""
Configuration settings for Notion integration
Notion integration configs
"""
NOTION_CLIENT_ID: Optional[str] = Field(
description="Client ID for Notion API authentication. Required for OAuth 2.0 flow.",
description="Notion client ID",
default=None,
)
NOTION_CLIENT_SECRET: Optional[str] = Field(
description="Client secret for Notion API authentication. Required for OAuth 2.0 flow.",
description="Notion client secret key",
default=None,
)
NOTION_INTEGRATION_TYPE: Optional[str] = Field(
description="Type of Notion integration."
" Set to 'internal' for internal integrations, or None for public integrations.",
description="Notion integration type, default to None, available values: internal.",
default=None,
)
NOTION_INTERNAL_SECRET: Optional[str] = Field(
description="Secret key for internal Notion integrations. Required when NOTION_INTEGRATION_TYPE is 'internal'.",
description="Notion internal secret key",
default=None,
)
NOTION_INTEGRATION_TOKEN: Optional[str] = Field(
description="Integration token for Notion API access. Used for direct API calls without OAuth flow.",
description="Notion integration token",
default=None,
)

View File

@@ -6,23 +6,20 @@ from pydantic_settings import BaseSettings
class SentryConfig(BaseSettings):
"""
Configuration settings for Sentry error tracking and performance monitoring
Sentry configs
"""
SENTRY_DSN: Optional[str] = Field(
description="Sentry Data Source Name (DSN)."
" This is the unique identifier of your Sentry project, used to send events to the correct project.",
description="Sentry DSN",
default=None,
)
SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field(
description="Sample rate for Sentry performance monitoring traces."
" Value between 0.0 and 1.0, where 1.0 means 100% of traces are sent to Sentry.",
description="Sentry trace sample rate",
default=1.0,
)
SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field(
description="Sample rate for Sentry profiling."
" Value between 0.0 and 1.0, where 1.0 means 100% of profiles are sent to Sentry.",
description="Sentry profiles sample rate",
default=1.0,
)

View File

@@ -8,143 +8,145 @@ from configs.feature.hosted_service import HostedServiceConfig
class SecurityConfig(BaseSettings):
"""
Security-related configurations for the application
Secret Key configs
"""
SECRET_KEY: Optional[str] = Field(
description="Secret key for secure session cookie signing."
description="Your App secret key will be used for securely signing the session cookie"
"Make sure you are changing this key for your deployment with a strong key."
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
"You can generate a strong key using `openssl rand -base64 42`."
"Alternatively you can set it with `SECRET_KEY` environment variable.",
default=None,
)
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
description="Duration in hours for which a password reset token remains valid",
description="Expiry time in hours for reset token",
default=24,
)
class AppExecutionConfig(BaseSettings):
"""
Configuration parameters for application execution
App Execution configs
"""
APP_MAX_EXECUTION_TIME: PositiveInt = Field(
description="Maximum allowed execution time for the application in seconds",
description="execution timeout in seconds for app execution",
default=1200,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Maximum number of concurrent active requests per app (0 for unlimited)",
description="max active request per app, 0 means unlimited",
default=0,
)
class CodeExecutionSandboxConfig(BaseSettings):
"""
Configuration for the code execution sandbox environment
Code Execution Sandbox configs
"""
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
description="URL endpoint for the code execution service",
description="endpoint URL of code execution servcie",
default="http://sandbox:8194",
)
CODE_EXECUTION_API_KEY: str = Field(
description="API key for accessing the code execution service",
description="API key for code execution service",
default="dify-sandbox",
)
CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field(
description="Connection timeout in seconds for code execution requests",
description="connect timeout in seconds for code execution request",
default=10.0,
)
CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field(
description="Read timeout in seconds for code execution requests",
description="read timeout in seconds for code execution request",
default=60.0,
)
CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field(
description="Write timeout in seconds for code execution request",
description="write timeout in seconds for code execution request",
default=10.0,
)
CODE_MAX_NUMBER: PositiveInt = Field(
description="Maximum allowed numeric value in code execution",
description="max depth for code execution",
default=9223372036854775807,
)
CODE_MIN_NUMBER: NegativeInt = Field(
description="Minimum allowed numeric value in code execution",
description="",
default=-9223372036854775807,
)
CODE_MAX_DEPTH: PositiveInt = Field(
description="Maximum allowed depth for nested structures in code execution",
description="max depth for code execution",
default=5,
)
CODE_MAX_PRECISION: PositiveInt = Field(
description="mMaximum number of decimal places for floating-point numbers in code execution",
description="max precision digits for float type in code execution",
default=20,
)
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
description="Maximum allowed length for strings in code execution",
description="max string length for code execution",
default=80000,
)
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
description="Maximum allowed length for string arrays in code execution",
description="",
default=30,
)
CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field(
description="Maximum allowed length for object arrays in code execution",
description="",
default=30,
)
CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field(
description="Maximum allowed length for numeric arrays in code execution",
description="",
default=1000,
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
Module URL configs
"""
CONSOLE_API_URL: str = Field(
description="Base URL for the console API,"
"used for login authentication callback or notion integration callbacks",
description="The backend URL prefix of the console API."
"used to concatenate the login authorization callback or notion integration callback.",
default="",
)
CONSOLE_WEB_URL: str = Field(
description="Base URL for the console web interface," "used for frontend references and CORS configuration",
description="The front-end URL prefix of the console web."
"used to concatenate some front-end addresses and for CORS configuration use.",
default="",
)
SERVICE_API_URL: str = Field(
description="Base URL for the service API, displayed to users for API access",
description="Service API Url prefix." "used to display Service API Base Url to the front-end.",
default="",
)
APP_WEB_URL: str = Field(
description="Base URL for the web application, used for frontend references",
description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.",
default="",
)
class FileAccessConfig(BaseSettings):
"""
Configuration for file access and handling
File Access configs
"""
FILES_URL: str = Field(
description="Base URL for file preview or download,"
" used for frontend display and multi-model inputs"
description="File preview or download Url prefix."
" used to display File preview or download Url to the front-end or as Multi-model inputs;"
"Url is signed and has expiration time.",
validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"),
alias_priority=1,
@@ -152,49 +154,49 @@ class FileAccessConfig(BaseSettings):
)
FILES_ACCESS_TIMEOUT: int = Field(
description="Expiration time in seconds for file access URLs",
description="timeout in seconds for file accessing",
default=300,
)
class FileUploadConfig(BaseSettings):
"""
Configuration for file upload limitations
File Uploading configs
"""
UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed file size for uploads in megabytes",
description="size limit in Megabytes for uploading files",
default=15,
)
UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field(
description="Maximum number of files allowed in a single upload batch",
description="batch size limit for uploading files",
default=5,
)
UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed image file size for uploads in megabytes",
description="image file size limit in Megabytes for uploading files",
default=10,
)
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
description="Maximum number of files allowed in a batch upload operation",
description="", # todo: to be clarified
default=20,
)
class HttpConfig(BaseSettings):
"""
HTTP-related configurations for the application
HTTP configs
"""
API_COMPRESSION_ENABLED: bool = Field(
description="Enable or disable gzip compression for HTTP responses",
description="whether to enable HTTP response compression of gzip",
default=False,
)
inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
description="Comma-separated list of allowed origins for CORS in the console",
description="",
validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"),
default="",
)
@@ -216,360 +218,359 @@ class HttpConfig(BaseSettings):
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="Maximum connection timeout in seconds for HTTP requests")
PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
] = 10
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
PositiveInt, Field(ge=60, description="Maximum read timeout in seconds for HTTP requests")
PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
] = 60
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="Maximum write timeout in seconds for HTTP requests")
PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request")
] = 20
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
description="Maximum allowed size in bytes for binary data in HTTP requests",
description="",
default=10 * 1024 * 1024,
)
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field(
description="Maximum allowed size in bytes for text data in HTTP requests",
description="",
default=1 * 1024 * 1024,
)
SSRF_PROXY_HTTP_URL: Optional[str] = Field(
description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)",
description="HTTP URL for SSRF proxy",
default=None,
)
SSRF_PROXY_HTTPS_URL: Optional[str] = Field(
description="Proxy URL for HTTPS requests to prevent Server-Side Request Forgery (SSRF)",
description="HTTPS URL for SSRF proxy",
default=None,
)
class InnerAPIConfig(BaseSettings):
"""
Configuration for internal API functionality
Inner API configs
"""
INNER_API: bool = Field(
description="Enable or disable the internal API",
description="whether to enable the inner API",
default=False,
)
INNER_API_KEY: Optional[str] = Field(
description="API key for accessing the internal API",
description="The inner API key is used to authenticate the inner API",
default=None,
)
class LoggingConfig(BaseSettings):
"""
Configuration for application logging
Logging configs
"""
LOG_LEVEL: str = Field(
description="Logging level, default to INFO. Set to ERROR for production environments.",
description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.",
default="INFO",
)
LOG_FILE: Optional[str] = Field(
description="File path for log output.",
description="logging output file path",
default=None,
)
LOG_FORMAT: str = Field(
description="Format string for log messages",
description="log format",
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
)
LOG_DATEFORMAT: Optional[str] = Field(
description="Date format string for log timestamps",
description="log date format",
default=None,
)
LOG_TZ: Optional[str] = Field(
description="Timezone for log timestamps (e.g., 'America/New_York')",
description="specify log timezone, eg: America/New_York",
default=None,
)
class ModelLoadBalanceConfig(BaseSettings):
"""
Configuration for model load balancing
Model load balance configs
"""
MODEL_LB_ENABLED: bool = Field(
description="Enable or disable load balancing for models",
description="whether to enable model load balancing",
default=False,
)
class BillingConfig(BaseSettings):
"""
Configuration for platform billing features
Platform Billing Configurations
"""
BILLING_ENABLED: bool = Field(
description="Enable or disable billing functionality",
description="whether to enable billing",
default=False,
)
class UpdateConfig(BaseSettings):
"""
Configuration for application update checks
Update configs
"""
CHECK_UPDATE_URL: str = Field(
description="URL to check for application updates",
description="url for checking updates",
default="https://updates.dify.ai",
)
class WorkflowConfig(BaseSettings):
"""
Configuration for workflow execution
Workflow feature configs
"""
WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field(
description="Maximum number of steps allowed in a single workflow execution",
description="max execution steps in single workflow execution",
default=500,
)
WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field(
description="Maximum execution time in seconds for a single workflow",
description="max execution time in seconds in single workflow execution",
default=1200,
)
WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field(
description="Maximum allowed depth for nested workflow calls",
description="max depth of calling in single workflow execution",
default=5,
)
MAX_VARIABLE_SIZE: PositiveInt = Field(
description="Maximum size in bytes for a single variable in workflows. Default to 5KB.",
description="The maximum size in bytes of a variable. default to 5KB.",
default=5 * 1024,
)
class OAuthConfig(BaseSettings):
"""
Configuration for OAuth authentication
oauth configs
"""
OAUTH_REDIRECT_PATH: str = Field(
description="Redirect path for OAuth authentication callbacks",
description="redirect path for OAuth",
default="/console/api/oauth/authorize",
)
GITHUB_CLIENT_ID: Optional[str] = Field(
description="GitHub OAuth client secret",
description="GitHub client id for OAuth",
default=None,
)
GITHUB_CLIENT_SECRET: Optional[str] = Field(
description="GitHub OAuth client secret",
description="GitHub client secret key for OAuth",
default=None,
)
GOOGLE_CLIENT_ID: Optional[str] = Field(
description="Google OAuth client ID",
description="Google client id for OAuth",
default=None,
)
GOOGLE_CLIENT_SECRET: Optional[str] = Field(
description="Google OAuth client secret",
description="Google client secret key for OAuth",
default=None,
)
class ModerationConfig(BaseSettings):
"""
Configuration for content moderation
Moderation in app configs.
"""
MODERATION_BUFFER_SIZE: PositiveInt = Field(
description="Size of the buffer for content moderation processing",
description="buffer size for moderation",
default=300,
)
class ToolConfig(BaseSettings):
"""
Configuration for tool management
Tool configs
"""
TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field(
description="Maximum age in seconds for caching tool icons",
description="max age in seconds for tool icon caching",
default=3600,
)
class MailConfig(BaseSettings):
"""
Configuration for email services
Mail Configurations
"""
MAIL_TYPE: Optional[str] = Field(
description="Email service provider type ('smtp' or 'resend'), default to None.",
description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
default=None,
)
MAIL_DEFAULT_SEND_FROM: Optional[str] = Field(
description="Default email address to use as the sender",
description="default email address for sending from ",
default=None,
)
RESEND_API_KEY: Optional[str] = Field(
description="API key for Resend email service",
description="API key for Resend",
default=None,
)
RESEND_API_URL: Optional[str] = Field(
description="API URL for Resend email service",
description="API URL for Resend",
default=None,
)
SMTP_SERVER: Optional[str] = Field(
description="SMTP server hostname",
description="smtp server host",
default=None,
)
SMTP_PORT: Optional[int] = Field(
description="SMTP server port number",
description="smtp server port",
default=465,
)
SMTP_USERNAME: Optional[str] = Field(
description="Username for SMTP authentication",
description="smtp server username",
default=None,
)
SMTP_PASSWORD: Optional[str] = Field(
description="Password for SMTP authentication",
description="smtp server password",
default=None,
)
SMTP_USE_TLS: bool = Field(
description="Enable TLS encryption for SMTP connections",
description="whether to use TLS connection to smtp server",
default=False,
)
SMTP_OPPORTUNISTIC_TLS: bool = Field(
description="Enable opportunistic TLS for SMTP connections",
description="whether to use opportunistic TLS connection to smtp server",
default=False,
)
class RagEtlConfig(BaseSettings):
"""
Configuration for RAG ETL processes
RAG ETL Configurations.
"""
ETL_TYPE: str = Field(
description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'",
description="RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ",
default="dify",
)
KEYWORD_DATA_SOURCE_TYPE: str = Field(
description="Data source type for keyword extraction"
" ('database' or other supported types), default to 'database'",
description="source type for keyword data, default to `database`, available values are `database` .",
default="database",
)
UNSTRUCTURED_API_URL: Optional[str] = Field(
description="API URL for Unstructured.io service",
description="API URL for Unstructured",
default=None,
)
UNSTRUCTURED_API_KEY: Optional[str] = Field(
description="API key for Unstructured.io service",
description="API key for Unstructured",
default=None,
)
class DataSetConfig(BaseSettings):
"""
Configuration for dataset management
Dataset configs
"""
CLEAN_DAY_SETTING: PositiveInt = Field(
description="Interval in days for dataset cleanup operations",
description="interval in days for cleaning up dataset",
default=30,
)
DATASET_OPERATOR_ENABLED: bool = Field(
description="Enable or disable dataset operator functionality",
description="whether to enable dataset operator",
default=False,
)
class WorkspaceConfig(BaseSettings):
"""
Configuration for workspace management
Workspace configs
"""
INVITE_EXPIRY_HOURS: PositiveInt = Field(
description="Expiration time in hours for workspace invitation links",
description="workspaces invitation expiration in hours",
default=72,
)
class IndexingConfig(BaseSettings):
"""
Configuration for indexing operations
Indexing configs.
"""
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field(
description="Maximum token length for text segmentation during indexing",
description="max segmentation token length for indexing",
default=1000,
)
class ImageFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
description="multi model send image format, support base64, url, default is base64",
default="base64",
)
class CeleryBeatConfig(BaseSettings):
CELERY_BEAT_SCHEDULER_TIME: int = Field(
description="Interval in days for Celery Beat scheduler execution, default to 1 day",
description="the time of the celery scheduler, default to 1 day",
default=1,
)
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description="Comma-separated list of pinned model providers",
description="The heads of model providers",
default="",
)
POSITION_PROVIDER_INCLUDES: str = Field(
description="Comma-separated list of included model providers",
description="The included model providers",
default="",
)
POSITION_PROVIDER_EXCLUDES: str = Field(
description="Comma-separated list of excluded model providers",
description="The excluded model providers",
default="",
)
POSITION_TOOL_PINS: str = Field(
description="Comma-separated list of pinned tools",
description="The heads of tools",
default="",
)
POSITION_TOOL_INCLUDES: str = Field(
description="Comma-separated list of included tools",
description="The included tools",
default="",
)
POSITION_TOOL_EXCLUDES: str = Field(
description="Comma-separated list of excluded tools",
description="The excluded tools",
default="",
)

View File

@@ -6,31 +6,31 @@ from pydantic_settings import BaseSettings
class HostedOpenAiConfig(BaseSettings):
"""
Configuration for hosted OpenAI service
Hosted OpenAI service config
"""
HOSTED_OPENAI_API_KEY: Optional[str] = Field(
description="API key for hosted OpenAI service",
description="",
default=None,
)
HOSTED_OPENAI_API_BASE: Optional[str] = Field(
description="Base URL for hosted OpenAI API",
description="",
default=None,
)
HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field(
description="Organization ID for hosted OpenAI service",
description="",
default=None,
)
HOSTED_OPENAI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted OpenAI service",
description="",
default=False,
)
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
description="",
default="gpt-3.5-turbo,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-instruct,"
@@ -42,17 +42,17 @@ class HostedOpenAiConfig(BaseSettings):
)
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted OpenAI service usage",
description="",
default=200,
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted OpenAI service",
description="",
default=False,
)
HOSTED_OPENAI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
description="",
default="gpt-4,"
"gpt-4-turbo-preview,"
"gpt-4-turbo-2024-04-09,"
@@ -71,122 +71,124 @@ class HostedOpenAiConfig(BaseSettings):
class HostedAzureOpenAiConfig(BaseSettings):
"""
Configuration for hosted Azure OpenAI service
Hosted OpenAI service config
"""
HOSTED_AZURE_OPENAI_ENABLED: bool = Field(
description="Enable hosted Azure OpenAI service",
description="",
default=False,
)
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
description="API key for hosted Azure OpenAI service",
description="",
default=None,
)
HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field(
description="Base URL for hosted Azure OpenAI API",
description="",
default=None,
)
HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted Azure OpenAI service usage",
description="",
default=200,
)
class HostedAnthropicConfig(BaseSettings):
"""
Configuration for hosted Anthropic service
Hosted Azure OpenAI service config
"""
HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field(
description="Base URL for hosted Anthropic API",
description="",
default=None,
)
HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field(
description="API key for hosted Anthropic service",
description="",
default=None,
)
HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Anthropic service",
description="",
default=False,
)
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted Anthropic service usage",
description="",
default=600000,
)
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
description="",
default=False,
)
class HostedMinmaxConfig(BaseSettings):
"""
Configuration for hosted Minmax service
Hosted Minmax service config
"""
HOSTED_MINIMAX_ENABLED: bool = Field(
description="Enable hosted Minmax service",
description="",
default=False,
)
class HostedSparkConfig(BaseSettings):
"""
Configuration for hosted Spark service
Hosted Spark service config
"""
HOSTED_SPARK_ENABLED: bool = Field(
description="Enable hosted Spark service",
description="",
default=False,
)
class HostedZhipuAIConfig(BaseSettings):
"""
Configuration for hosted ZhipuAI service
Hosted Minmax service config
"""
HOSTED_ZHIPUAI_ENABLED: bool = Field(
description="Enable hosted ZhipuAI service",
description="",
default=False,
)
class HostedModerationConfig(BaseSettings):
"""
Configuration for hosted Moderation service
Hosted Moderation service config
"""
HOSTED_MODERATION_ENABLED: bool = Field(
description="Enable hosted Moderation service",
description="",
default=False,
)
HOSTED_MODERATION_PROVIDERS: str = Field(
description="Comma-separated list of moderation providers",
description="",
default="",
)
class HostedFetchAppTemplateConfig(BaseSettings):
"""
Configuration for fetching app templates
Hosted Moderation service config
"""
HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
description="Mode for fetching app templates: remote, db, or builtin" " default to remote,",
description="the mode for fetching app templates,"
" default to remote,"
" available values: remote, db, builtin",
default="remote",
)
HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field(
description="Domain for fetching remote app templates",
description="the domain for fetching remote app templates",
default="https://tmpl.dify.ai",
)

View File

@@ -1,7 +1,7 @@
from typing import Any, Optional
from urllib.parse import quote_plus
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.middleware.cache.redis_config import RedisConfig
@@ -9,10 +9,8 @@ from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorag
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
@@ -31,71 +29,70 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseSettings):
STORAGE_TYPE: str = Field(
description="Type of storage to use."
" Options: 'local', 's3', 'azure-blob', 'aliyun-oss', 'google-storage'. Default is 'local'.",
description="storage type,"
" default to `local`,"
" available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.",
default="local",
)
STORAGE_LOCAL_PATH: str = Field(
description="Path for local storage when STORAGE_TYPE is set to 'local'.",
description="local storage path",
default="storage",
)
class VectorStoreConfig(BaseSettings):
VECTOR_STORE: Optional[str] = Field(
description="Type of vector store to use for efficient similarity search."
" Set to None if not using a vector store.",
description="vector store type",
default=None,
)
class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field(
description="Method for keyword extraction and storage."
" Default is 'jieba', a Chinese text segmentation library.",
description="keyword store type",
default="jieba",
)
class DatabaseConfig:
DB_HOST: str = Field(
description="Hostname or IP address of the database server.",
description="db host",
default="localhost",
)
DB_PORT: PositiveInt = Field(
description="Port number for database connection.",
description="db port",
default=5432,
)
DB_USERNAME: str = Field(
description="Username for database authentication.",
description="db username",
default="postgres",
)
DB_PASSWORD: str = Field(
description="Password for database authentication.",
description="db password",
default="",
)
DB_DATABASE: str = Field(
description="Name of the database to connect to.",
description="db database",
default="dify",
)
DB_CHARSET: str = Field(
description="Character set for database connection.",
description="db charset",
default="",
)
DB_EXTRAS: str = Field(
description="Additional database connection parameters. Example: 'keepalives_idle=60&keepalives=1'",
description="db extras options. Example: keepalives_idle=60&keepalives=1",
default="",
)
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
description="Database URI scheme for SQLAlchemy connection.",
description="db uri scheme",
default="postgresql",
)
@@ -113,27 +110,27 @@ class DatabaseConfig:
)
SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field(
description="Maximum number of database connections in the pool.",
description="pool size of SqlAlchemy",
default=30,
)
SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field(
description="Maximum number of connections that can be created beyond the pool_size.",
description="max overflows for SqlAlchemy",
default=10,
)
SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field(
description="Number of seconds after which a connection is automatically recycled.",
description="SqlAlchemy pool recycle",
default=3600,
)
SQLALCHEMY_POOL_PRE_PING: bool = Field(
description="If True, enables connection pool pre-ping feature to check connections.",
description="whether to enable pool pre-ping in SqlAlchemy",
default=False,
)
SQLALCHEMY_ECHO: bool | str = Field(
description="If True, SQLAlchemy will log all SQL statements.",
description="whether to enable SqlAlchemy echo",
default=False,
)
@@ -151,30 +148,15 @@ class DatabaseConfig:
class CeleryConfig(DatabaseConfig):
CELERY_BACKEND: str = Field(
description="Backend for Celery task results. Options: 'database', 'redis'.",
description="Celery backend, available values are `database`, `redis`",
default="database",
)
CELERY_BROKER_URL: Optional[str] = Field(
description="URL of the message broker for Celery tasks.",
description="CELERY_BROKER_URL",
default=None,
)
CELERY_USE_SENTINEL: Optional[bool] = Field(
description="Whether to use Redis Sentinel for high availability.",
default=False,
)
CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
description="Name of the Redis Sentinel master.",
default=None,
)
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Timeout for Redis Sentinel socket operations in seconds.",
default=0.1,
)
@computed_field
@property
def CELERY_RESULT_BACKEND(self) -> str | None:
@@ -202,8 +184,6 @@ class MiddlewareConfig(
AzureBlobStorageConfig,
GoogleCloudStorageConfig,
TencentCloudCOSStorageConfig,
HuaweiCloudOBSStorageConfig,
VolcengineTOSStorageConfig,
S3StorageConfig,
OCIStorageConfig,
# configs of vdb and vdb providers

View File

@@ -1,70 +1,40 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class RedisConfig(BaseSettings):
"""
Configuration settings for Redis connection
Redis configs
"""
REDIS_HOST: str = Field(
description="Hostname or IP address of the Redis server",
description="Redis host",
default="localhost",
)
REDIS_PORT: PositiveInt = Field(
description="Port number on which the Redis server is listening",
description="Redis port",
default=6379,
)
REDIS_USERNAME: Optional[str] = Field(
description="Username for Redis authentication (if required)",
description="Redis username",
default=None,
)
REDIS_PASSWORD: Optional[str] = Field(
description="Password for Redis authentication (if required)",
description="Redis password",
default=None,
)
REDIS_DB: NonNegativeInt = Field(
description="Redis database number to use (0-15)",
description="Redis database id, default to 0",
default=0,
)
REDIS_USE_SSL: bool = Field(
description="Enable SSL/TLS for the Redis connection",
description="whether to use SSL for Redis connection",
default=False,
)
REDIS_USE_SENTINEL: Optional[bool] = Field(
description="Enable Redis Sentinel mode for high availability",
default=False,
)
REDIS_SENTINELS: Optional[str] = Field(
description="Comma-separated list of Redis Sentinel nodes (host:port)",
default=None,
)
REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field(
description="Name of the Redis Sentinel service to monitor",
default=None,
)
REDIS_SENTINEL_USERNAME: Optional[str] = Field(
description="Username for Redis Sentinel authentication (if required)",
default=None,
)
REDIS_SENTINEL_PASSWORD: Optional[str] = Field(
description="Password for Redis Sentinel authentication (if required)",
default=None,
)
REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Socket timeout in seconds for Redis Sentinel connections",
default=0.1,
)

View File

@@ -6,40 +6,40 @@ from pydantic_settings import BaseSettings
class AliyunOSSStorageConfig(BaseSettings):
"""
Configuration settings for Aliyun Object Storage Service (OSS)
Aliyun storage configs
"""
ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field(
description="Name of the Aliyun OSS bucket to store and retrieve objects",
description="Aliyun OSS bucket name",
default=None,
)
ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field(
description="Access key ID for authenticating with Aliyun OSS",
description="Aliyun OSS access key",
default=None,
)
ALIYUN_OSS_SECRET_KEY: Optional[str] = Field(
description="Secret access key for authenticating with Aliyun OSS",
description="Aliyun OSS secret key",
default=None,
)
ALIYUN_OSS_ENDPOINT: Optional[str] = Field(
description="URL of the Aliyun OSS endpoint for your chosen region",
description="Aliyun OSS endpoint URL",
default=None,
)
ALIYUN_OSS_REGION: Optional[str] = Field(
description="Aliyun OSS region where your bucket is located (e.g., 'oss-cn-hangzhou')",
description="Aliyun OSS region",
default=None,
)
ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field(
description="Version of the authentication protocol to use with Aliyun OSS (e.g., 'v4')",
description="Aliyun OSS authentication version",
default=None,
)
ALIYUN_OSS_PATH: Optional[str] = Field(
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
description="Aliyun OSS path",
default=None,
)

View File

@@ -6,40 +6,40 @@ from pydantic_settings import BaseSettings
class S3StorageConfig(BaseSettings):
"""
Configuration settings for S3-compatible object storage
S3 storage configs
"""
S3_ENDPOINT: Optional[str] = Field(
description="URL of the S3-compatible storage endpoint (e.g., 'https://s3.amazonaws.com')",
description="S3 storage endpoint",
default=None,
)
S3_REGION: Optional[str] = Field(
description="Region where the S3 bucket is located (e.g., 'us-east-1')",
description="S3 storage region",
default=None,
)
S3_BUCKET_NAME: Optional[str] = Field(
description="Name of the S3 bucket to store and retrieve objects",
description="S3 storage bucket name",
default=None,
)
S3_ACCESS_KEY: Optional[str] = Field(
description="Access key ID for authenticating with the S3 service",
description="S3 storage access key",
default=None,
)
S3_SECRET_KEY: Optional[str] = Field(
description="Secret access key for authenticating with the S3 service",
description="S3 storage secret key",
default=None,
)
S3_ADDRESS_STYLE: str = Field(
description="S3 addressing style: 'auto', 'path', or 'virtual'",
description="S3 storage address style",
default="auto",
)
S3_USE_AWS_MANAGED_IAM: bool = Field(
description="Use AWS managed IAM roles for authentication instead of access/secret keys",
description="whether to use aws managed IAM for S3",
default=False,
)

View File

@@ -6,25 +6,25 @@ from pydantic_settings import BaseSettings
class AzureBlobStorageConfig(BaseSettings):
"""
Configuration settings for Azure Blob Storage
Azure Blob storage configs
"""
AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field(
description="Name of the Azure Storage account (e.g., 'mystorageaccount')",
description="Azure Blob account name",
default=None,
)
AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field(
description="Access key for authenticating with the Azure Storage account",
description="Azure Blob account key",
default=None,
)
AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field(
description="Name of the Azure Blob container to store and retrieve objects",
description="Azure Blob container name",
default=None,
)
AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field(
description="URL of the Azure Blob storage endpoint (e.g., 'https://mystorageaccount.blob.core.windows.net')",
description="Azure Blob account URL",
default=None,
)

View File

@@ -6,15 +6,15 @@ from pydantic_settings import BaseSettings
class GoogleCloudStorageConfig(BaseSettings):
"""
Configuration settings for Google Cloud Storage
Google Cloud storage configs
"""
GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field(
description="Name of the Google Cloud Storage bucket to store and retrieve objects (e.g., 'my-gcs-bucket')",
description="Google Cloud storage bucket name",
default=None,
)
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field(
description="Base64-encoded JSON key file for Google Cloud service account authentication",
description="Google Cloud storage service account json base64",
default=None,
)

View File

@@ -1,29 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
class HuaweiCloudOBSStorageConfig(BaseModel):
"""
Configuration settings for Huawei Cloud Object Storage Service (OBS)
"""
HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
description="Name of the Huawei Cloud OBS bucket to store and retrieve objects (e.g., 'my-obs-bucket')",
default=None,
)
HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
description="Access Key ID for authenticating with Huawei Cloud OBS",
default=None,
)
HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
description="Secret Access Key for authenticating with Huawei Cloud OBS",
default=None,
)
HUAWEI_OBS_SERVER: Optional[str] = Field(
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
default=None,
)

View File

@@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class OCIStorageConfig(BaseSettings):
"""
Configuration settings for Oracle Cloud Infrastructure (OCI) Object Storage
OCI storage configs
"""
OCI_ENDPOINT: Optional[str] = Field(
description="URL of the OCI Object Storage endpoint (e.g., 'https://objectstorage.us-phoenix-1.oraclecloud.com')",
description="OCI storage endpoint",
default=None,
)
OCI_REGION: Optional[str] = Field(
description="OCI region where the bucket is located (e.g., 'us-phoenix-1')",
description="OCI storage region",
default=None,
)
OCI_BUCKET_NAME: Optional[str] = Field(
description="Name of the OCI Object Storage bucket to store and retrieve objects (e.g., 'my-oci-bucket')",
description="OCI storage bucket name",
default=None,
)
OCI_ACCESS_KEY: Optional[str] = Field(
description="Access key (also known as API key) for authenticating with OCI Object Storage",
description="OCI storage access key",
default=None,
)
OCI_SECRET_KEY: Optional[str] = Field(
description="Secret key associated with the access key for authenticating with OCI Object Storage",
description="OCI storage secret key",
default=None,
)

View File

@@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class TencentCloudCOSStorageConfig(BaseSettings):
"""
Configuration settings for Tencent Cloud Object Storage (COS)
Tencent Cloud COS storage configs
"""
TENCENT_COS_BUCKET_NAME: Optional[str] = Field(
description="Name of the Tencent Cloud COS bucket to store and retrieve objects",
description="Tencent Cloud COS bucket name",
default=None,
)
TENCENT_COS_REGION: Optional[str] = Field(
description="Tencent Cloud region where the COS bucket is located (e.g., 'ap-guangzhou')",
description="Tencent Cloud COS region",
default=None,
)
TENCENT_COS_SECRET_ID: Optional[str] = Field(
description="SecretId for authenticating with Tencent Cloud COS (part of API credentials)",
description="Tencent Cloud COS secret id",
default=None,
)
TENCENT_COS_SECRET_KEY: Optional[str] = Field(
description="SecretKey for authenticating with Tencent Cloud COS (part of API credentials)",
description="Tencent Cloud COS secret key",
default=None,
)
TENCENT_COS_SCHEME: Optional[str] = Field(
description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
description="Tencent Cloud COS scheme",
default=None,
)

View File

@@ -1,34 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
class VolcengineTOSStorageConfig(BaseModel):
"""
Configuration settings for Volcengine Tinder Object Storage (TOS)
"""
VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
description="Name of the Volcengine TOS bucket to store and retrieve objects (e.g., 'my-tos-bucket')",
default=None,
)
VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
description="Access Key ID for authenticating with Volcengine TOS",
default=None,
)
VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
description="Secret Access Key for authenticating with Volcengine TOS",
default=None,
)
VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
description="URL of the Volcengine TOS endpoint (e.g., 'https://tos-cn-beijing.volces.com')",
default=None,
)
VOLCENGINE_TOS_REGION: Optional[str] = Field(
description="Volcengine region where the TOS bucket is located (e.g., 'cn-beijing')",
default=None,
)

View File

@@ -5,38 +5,33 @@ from pydantic import BaseModel, Field
class AnalyticdbConfig(BaseModel):
"""
Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL.
Configuration for connecting to AnalyticDB.
Refer to the following documentation for details on obtaining credentials:
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
"""
ANALYTICDB_KEY_ID: Optional[str] = Field(
default=None, description="The Access Key ID provided by Alibaba Cloud for API authentication."
default=None, description="The Access Key ID provided by Alibaba Cloud for authentication."
)
ANALYTICDB_KEY_SECRET: Optional[str] = Field(
default=None, description="The Secret Access Key corresponding to the Access Key ID for secure API access."
default=None, description="The Secret Access Key corresponding to the Access Key ID for secure access."
)
ANALYTICDB_REGION_ID: Optional[str] = Field(
default=None,
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou', 'ap-southeast-1').",
default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
)
ANALYTICDB_INSTANCE_ID: Optional[str] = Field(
default=None,
description="The unique identifier of the AnalyticDB instance you want to connect to.",
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')..",
)
ANALYTICDB_ACCOUNT: Optional[str] = Field(
default=None,
description="The account name used to log in to the AnalyticDB instance"
" (usually the initial account created with the instance).",
default=None, description="The account name used to log in to the AnalyticDB instance."
)
ANALYTICDB_PASSWORD: Optional[str] = Field(
default=None, description="The password associated with the AnalyticDB account for database authentication."
default=None, description="The password associated with the AnalyticDB account for authentication."
)
ANALYTICDB_NAMESPACE: Optional[str] = Field(
default=None, description="The namespace within AnalyticDB for schema isolation (if using namespace feature)."
default=None, description="The namespace within AnalyticDB for schema isolation."
)
ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field(
default=None,
description="The password for accessing the specified namespace within the AnalyticDB instance"
" (if namespace feature is enabled).",
default=None, description="The password for accessing the specified namespace within the AnalyticDB instance."
)

View File

@@ -6,35 +6,35 @@ from pydantic_settings import BaseSettings
class ChromaConfig(BaseSettings):
"""
Configuration settings for Chroma vector database
Chroma configs
"""
CHROMA_HOST: Optional[str] = Field(
description="Hostname or IP address of the Chroma server (e.g., 'localhost' or '192.168.1.100')",
description="Chroma host",
default=None,
)
CHROMA_PORT: PositiveInt = Field(
description="Port number on which the Chroma server is listening (default is 8000)",
description="Chroma port",
default=8000,
)
CHROMA_TENANT: Optional[str] = Field(
description="Tenant identifier for multi-tenancy support in Chroma",
description="Chroma database",
default=None,
)
CHROMA_DATABASE: Optional[str] = Field(
description="Name of the Chroma database to connect to",
description="Chroma database",
default=None,
)
CHROMA_AUTH_PROVIDER: Optional[str] = Field(
description="Authentication provider for Chroma (e.g., 'basic', 'token', or a custom provider)",
description="Chroma authentication provider",
default=None,
)
CHROMA_AUTH_CREDENTIALS: Optional[str] = Field(
description="Authentication credentials for Chroma (format depends on the auth provider)",
description="Chroma authentication credentials",
default=None,
)

View File

@@ -6,25 +6,25 @@ from pydantic_settings import BaseSettings
class ElasticsearchConfig(BaseSettings):
"""
Configuration settings for Elasticsearch
Elasticsearch configs
"""
ELASTICSEARCH_HOST: Optional[str] = Field(
description="Hostname or IP address of the Elasticsearch server (e.g., 'localhost' or '192.168.1.100')",
description="Elasticsearch host",
default="127.0.0.1",
)
ELASTICSEARCH_PORT: PositiveInt = Field(
description="Port number on which the Elasticsearch server is listening (default is 9200)",
description="Elasticsearch port",
default=9200,
)
ELASTICSEARCH_USERNAME: Optional[str] = Field(
description="Username for authenticating with Elasticsearch (default is 'elastic')",
description="Elasticsearch username",
default="elastic",
)
ELASTICSEARCH_PASSWORD: Optional[str] = Field(
description="Password for authenticating with Elasticsearch (default is 'elastic')",
description="Elasticsearch password",
default="elastic",
)

View File

@@ -1,35 +1,40 @@
from typing import Optional
from pydantic import Field
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class MilvusConfig(BaseSettings):
"""
Configuration settings for Milvus vector database
Milvus configs
"""
MILVUS_URI: Optional[str] = Field(
description="URI for connecting to the Milvus server (e.g., 'http://localhost:19530' or 'https://milvus-instance.example.com:19530')",
default="http://127.0.0.1:19530",
)
MILVUS_TOKEN: Optional[str] = Field(
description="Authentication token for Milvus, if token-based authentication is enabled",
MILVUS_HOST: Optional[str] = Field(
description="Milvus host",
default=None,
)
MILVUS_PORT: PositiveInt = Field(
description="Milvus RestFul API port",
default=9091,
)
MILVUS_USER: Optional[str] = Field(
description="Username for authenticating with Milvus, if username/password authentication is enabled",
description="Milvus user",
default=None,
)
MILVUS_PASSWORD: Optional[str] = Field(
description="Password for authenticating with Milvus, if username/password authentication is enabled",
description="Milvus password",
default=None,
)
MILVUS_SECURE: bool = Field(
description="whether to use SSL connection for Milvus",
default=False,
)
MILVUS_DATABASE: str = Field(
description="Name of the Milvus database to connect to (default is 'default')",
description="Milvus database, default to `default`",
default="default",
)

View File

@@ -3,35 +3,35 @@ from pydantic import BaseModel, Field, PositiveInt
class MyScaleConfig(BaseModel):
"""
Configuration settings for MyScale vector database
MyScale configs
"""
MYSCALE_HOST: str = Field(
description="Hostname or IP address of the MyScale server (e.g., 'localhost' or 'myscale.example.com')",
description="MyScale host",
default="localhost",
)
MYSCALE_PORT: PositiveInt = Field(
description="Port number on which the MyScale server is listening (default is 8123)",
description="MyScale port",
default=8123,
)
MYSCALE_USER: str = Field(
description="Username for authenticating with MyScale (default is 'default')",
description="MyScale user",
default="default",
)
MYSCALE_PASSWORD: str = Field(
description="Password for authenticating with MyScale (default is an empty string)",
description="MyScale password",
default="",
)
MYSCALE_DATABASE: str = Field(
description="Name of the MyScale database to connect to (default is 'default')",
description="MyScale database name",
default="default",
)
MYSCALE_FTS_PARAMS: str = Field(
description="Additional parameters for MyScale Full Text Search index)",
description="MyScale fts index parameters",
default="",
)

View File

@@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class OpenSearchConfig(BaseSettings):
"""
Configuration settings for OpenSearch
OpenSearch configs
"""
OPENSEARCH_HOST: Optional[str] = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
description="OpenSearch host",
default=None,
)
OPENSEARCH_PORT: PositiveInt = Field(
description="Port number on which the OpenSearch server is listening (default is 9200)",
description="OpenSearch port",
default=9200,
)
OPENSEARCH_USER: Optional[str] = Field(
description="Username for authenticating with OpenSearch",
description="OpenSearch user",
default=None,
)
OPENSEARCH_PASSWORD: Optional[str] = Field(
description="Password for authenticating with OpenSearch",
description="OpenSearch password",
default=None,
)
OPENSEARCH_SECURE: bool = Field(
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
description="whether to use SSL connection for OpenSearch",
default=False,
)

View File

@@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class OracleConfig(BaseSettings):
"""
Configuration settings for Oracle database
ORACLE configs
"""
ORACLE_HOST: Optional[str] = Field(
description="Hostname or IP address of the Oracle database server (e.g., 'localhost' or 'oracle.example.com')",
description="ORACLE host",
default=None,
)
ORACLE_PORT: Optional[PositiveInt] = Field(
description="Port number on which the Oracle database server is listening (default is 1521)",
description="ORACLE port",
default=1521,
)
ORACLE_USER: Optional[str] = Field(
description="Username for authenticating with the Oracle database",
description="ORACLE user",
default=None,
)
ORACLE_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the Oracle database",
description="ORACLE password",
default=None,
)
ORACLE_DATABASE: Optional[str] = Field(
description="Name of the Oracle database or service to connect to (e.g., 'ORCL' or 'pdborcl')",
description="ORACLE database",
default=None,
)

View File

@@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class PGVectorConfig(BaseSettings):
"""
Configuration settings for PGVector (PostgreSQL with vector extension)
PGVector configs
"""
PGVECTOR_HOST: Optional[str] = Field(
description="Hostname or IP address of the PostgreSQL server with PGVector extension (e.g., 'localhost')",
description="PGVector host",
default=None,
)
PGVECTOR_PORT: Optional[PositiveInt] = Field(
description="Port number on which the PostgreSQL server is listening (default is 5433)",
description="PGVector port",
default=5433,
)
PGVECTOR_USER: Optional[str] = Field(
description="Username for authenticating with the PostgreSQL database",
description="PGVector user",
default=None,
)
PGVECTOR_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the PostgreSQL database",
description="PGVector password",
default=None,
)
PGVECTOR_DATABASE: Optional[str] = Field(
description="Name of the PostgreSQL database to connect to",
description="PGVector database",
default=None,
)

View File

@@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class PGVectoRSConfig(BaseSettings):
"""
Configuration settings for PGVecto.RS (Rust-based vector extension for PostgreSQL)
PGVectoRS configs
"""
PGVECTO_RS_HOST: Optional[str] = Field(
description="Hostname or IP address of the PostgreSQL server with PGVecto.RS extension (e.g., 'localhost')",
description="PGVectoRS host",
default=None,
)
PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)",
description="PGVectoRS port",
default=5431,
)
PGVECTO_RS_USER: Optional[str] = Field(
description="Username for authenticating with the PostgreSQL database using PGVecto.RS",
description="PGVectoRS user",
default=None,
)
PGVECTO_RS_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the PostgreSQL database using PGVecto.RS",
description="PGVectoRS password",
default=None,
)
PGVECTO_RS_DATABASE: Optional[str] = Field(
description="Name of the PostgreSQL database with PGVecto.RS extension to connect to",
description="PGVectoRS database",
default=None,
)

View File

@@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class QdrantConfig(BaseSettings):
"""
Configuration settings for Qdrant vector database
Qdrant configs
"""
QDRANT_URL: Optional[str] = Field(
description="URL of the Qdrant server (e.g., 'http://localhost:6333' or 'https://qdrant.example.com')",
description="Qdrant url",
default=None,
)
QDRANT_API_KEY: Optional[str] = Field(
description="API key for authenticating with the Qdrant server",
description="Qdrant api key",
default=None,
)
QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
description="Timeout in seconds for Qdrant client operations (default is 20 seconds)",
description="Qdrant client timeout in seconds",
default=20,
)
QDRANT_GRPC_ENABLED: bool = Field(
description="Whether to enable gRPC support for Qdrant connection (True for gRPC, False for HTTP)",
description="whether enable grpc support for Qdrant connection",
default=False,
)
QDRANT_GRPC_PORT: PositiveInt = Field(
description="Port number for gRPC connection to Qdrant server (default is 6334)",
description="Qdrant grpc port",
default=6334,
)

View File

@@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class RelytConfig(BaseSettings):
"""
Configuration settings for Relyt database
Relyt configs
"""
RELYT_HOST: Optional[str] = Field(
description="Hostname or IP address of the Relyt server (e.g., 'localhost' or 'relyt.example.com')",
description="Relyt host",
default=None,
)
RELYT_PORT: PositiveInt = Field(
description="Port number on which the Relyt server is listening (default is 9200)",
description="Relyt port",
default=9200,
)
RELYT_USER: Optional[str] = Field(
description="Username for authenticating with the Relyt database",
description="Relyt user",
default=None,
)
RELYT_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the Relyt database",
description="Relyt password",
default=None,
)
RELYT_DATABASE: Optional[str] = Field(
description="Name of the Relyt database to connect to (default is 'default')",
description="Relyt database",
default="default",
)

View File

@@ -6,45 +6,45 @@ from pydantic_settings import BaseSettings
class TencentVectorDBConfig(BaseSettings):
"""
Configuration settings for Tencent Vector Database
Tencent Vector configs
"""
TENCENT_VECTOR_DB_URL: Optional[str] = Field(
description="URL of the Tencent Vector Database service (e.g., 'https://vectordb.tencentcloudapi.com')",
description="Tencent Vector URL",
default=None,
)
TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field(
description="API key for authenticating with the Tencent Vector Database service",
description="Tencent Vector API key",
default=None,
)
TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field(
description="Timeout in seconds for Tencent Vector Database operations (default is 30 seconds)",
description="Tencent Vector timeout in seconds",
default=30,
)
TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
description="Username for authenticating with the Tencent Vector Database (if required)",
description="Tencent Vector username",
default=None,
)
TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the Tencent Vector Database (if required)",
description="Tencent Vector password",
default=None,
)
TENCENT_VECTOR_DB_SHARD: PositiveInt = Field(
description="Number of shards for the Tencent Vector Database (default is 1)",
description="Tencent Vector sharding number",
default=1,
)
TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
description="Number of replicas for the Tencent Vector Database (default is 2)",
description="Tencent Vector replicas",
default=2,
)
TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
description="Name of the specific Tencent Vector Database to connect to",
description="Tencent Vector Database",
default=None,
)

View File

@@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class TiDBVectorConfig(BaseSettings):
"""
Configuration settings for TiDB Vector database
TiDB Vector configs
"""
TIDB_VECTOR_HOST: Optional[str] = Field(
description="Hostname or IP address of the TiDB Vector server (e.g., 'localhost' or 'tidb.example.com')",
description="TiDB Vector host",
default=None,
)
TIDB_VECTOR_PORT: Optional[PositiveInt] = Field(
description="Port number on which the TiDB Vector server is listening (default is 4000)",
description="TiDB Vector port",
default=4000,
)
TIDB_VECTOR_USER: Optional[str] = Field(
description="Username for authenticating with the TiDB Vector database",
description="TiDB Vector user",
default=None,
)
TIDB_VECTOR_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the TiDB Vector database",
description="TiDB Vector password",
default=None,
)
TIDB_VECTOR_DATABASE: Optional[str] = Field(
description="Name of the TiDB Vector database to connect to",
description="TiDB Vector database",
default=None,
)

View File

@@ -6,25 +6,25 @@ from pydantic_settings import BaseSettings
class WeaviateConfig(BaseSettings):
"""
Configuration settings for Weaviate vector database
Weaviate configs
"""
WEAVIATE_ENDPOINT: Optional[str] = Field(
description="URL of the Weaviate server (e.g., 'http://localhost:8080' or 'https://weaviate.example.com')",
description="Weaviate endpoint URL",
default=None,
)
WEAVIATE_API_KEY: Optional[str] = Field(
description="API key for authenticating with the Weaviate server",
description="Weaviate API key",
default=None,
)
WEAVIATE_GRPC_ENABLED: bool = Field(
description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)",
description="whether to enable gRPC for Weaviate connection",
default=True,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description="Number of objects to be processed in a single batch operation (default is 100)",
description="Weaviate batch size",
default=100,
)

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.8.3",
default="0.8.0-beta1",
)
COMMIT_SHA: str = Field(

View File

@@ -1,2 +1 @@
HIDDEN_VALUE = "[__HIDDEN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000"

File diff suppressed because one or more lines are too long

View File

@@ -60,15 +60,23 @@ class InsertExploreAppListApi(Resource):
site = app.site
if not site:
desc = args["desc"] or ""
copy_right = args["copyright"] or ""
privacy_policy = args["privacy_policy"] or ""
custom_disclaimer = args["custom_disclaimer"] or ""
desc = args["desc"] if args["desc"] else ""
copy_right = args["copyright"] if args["copyright"] else ""
privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
else:
desc = site.description or args["desc"] or ""
copy_right = site.copyright or args["copyright"] or ""
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
desc = site.description if site.description else args["desc"] if args["desc"] else ""
copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
privacy_policy = (
site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
)
custom_disclaimer = (
site.custom_disclaimer
if site.custom_disclaimer
else args["custom_disclaimer"]
if args["custom_disclaimer"]
else ""
)
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()

View File

@@ -57,7 +57,7 @@ class BaseApiKeyListResource(Resource):
def post(self, resource_id):
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.is_editor:
if not current_user.is_admin_or_owner:
raise Forbidden()
current_key_count = (

View File

@@ -94,15 +94,19 @@ class ChatMessageTextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") or text_to_speech.get("voice")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)

View File

@@ -109,7 +109,6 @@ class ChatMessageApi(Resource):
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args()

View File

@@ -20,7 +20,7 @@ from fields.conversation_fields import (
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
from libs.helper import DatetimeString
from libs.helper import datetime_string
from libs.login import login_required
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
@@ -36,8 +36,8 @@ class CompletionConversationApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
@@ -143,8 +143,8 @@ class ChatConversationApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
@@ -201,11 +201,7 @@ class ChatConversationApi(Resource):
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at >= start_datetime_utc)
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
@@ -214,11 +210,7 @@ class ChatConversationApi(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
query = query.where(Conversation.created_at < end_datetime_utc)
if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join(

View File

@@ -105,6 +105,8 @@ class ChatMessageListApi(Resource):
if rest_count > 0:
has_more = True
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)

View File

@@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.helper import datetime_string
from libs.login import login_required
from models.model import AppMode
@@ -25,17 +25,14 @@ class DailyMessageStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(*) AS message_count
FROM
messages
WHERE
app_id = :app_id"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@@ -48,7 +45,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@@ -58,10 +55,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@@ -82,17 +79,14 @@ class DailyConversationStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT messages.conversation_id) AS conversation_count
FROM
messages
WHERE
app_id = :app_id"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@@ -105,7 +99,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@@ -115,10 +109,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@@ -139,17 +133,14 @@ class DailyTerminalsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
FROM
messages
WHERE
app_id = :app_id"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@@ -162,7 +153,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@@ -172,10 +163,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@@ -196,18 +187,16 @@ class DailyTokenCostStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
SUM(total_price) AS total_price
FROM
messages
WHERE
app_id = :app_id"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
(sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
sum(total_price) as total_price
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@@ -220,7 +209,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@@ -230,10 +219,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@@ -256,26 +245,16 @@ class AverageSessionInteractionStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(subquery.message_count) AS interactions
FROM
(
SELECT
m.conversation_id,
COUNT(m.id) AS message_count
FROM
conversations c
JOIN
messages m
ON c.id = m.conversation_id
WHERE
c.override_model_configs IS NULL
AND c.app_id = :app_id"""
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(subquery.message_count) AS interactions
FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
FROM conversations c
JOIN messages m ON c.id = m.conversation_id
WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@@ -288,7 +267,7 @@ FROM
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND c.created_at >= :start"
sql_query += " and c.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@@ -298,19 +277,14 @@ FROM
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND c.created_at < :end"
sql_query += " and c.created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += """
GROUP BY m.conversation_id
) subquery
LEFT JOIN
conversations c
ON c.id = subquery.conversation_id
GROUP BY
date
ORDER BY
date"""
GROUP BY m.conversation_id) subquery
LEFT JOIN conversations c on c.id=subquery.conversation_id
GROUP BY date
ORDER BY date"""
response_data = []
@@ -333,21 +307,17 @@ class UserSatisfactionRateStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(m.id) AS message_count,
COUNT(mf.id) AS feedback_count
FROM
messages m
LEFT JOIN
message_feedbacks mf
ON mf.message_id=m.id AND mf.rating='like'
WHERE
m.app_id = :app_id"""
sql_query = """
SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
FROM messages m
LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
WHERE m.app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@@ -360,7 +330,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND m.created_at >= :start"
sql_query += " and m.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@@ -370,10 +340,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND m.created_at < :end"
sql_query += " and m.created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@@ -399,17 +369,16 @@ class AverageResponseTimeStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(provider_response_latency) AS latency
FROM
messages
WHERE
app_id = :app_id"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(provider_response_latency) as latency
FROM messages
WHERE app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@@ -422,7 +391,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@@ -432,10 +401,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@@ -456,20 +425,17 @@ class TokensPerSecondStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
CASE
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
CASE
WHEN SUM(provider_response_latency) = 0 THEN 0
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
END as tokens_per_second
FROM
messages
WHERE
app_id = :app_id"""
FROM messages
WHERE app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@@ -482,7 +448,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@@ -492,10 +458,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []

View File

@@ -166,8 +166,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
parser.add_argument("query", type=str, required=True, location="json", default="")
parser.add_argument("files", type=list, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
args = parser.parse_args()
try:
@@ -467,6 +465,6 @@ api.add_resource(
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
api.add_resource(
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
)
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")

View File

@@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.helper import datetime_string
from libs.login import login_required
from models.model import AppMode
from models.workflow import WorkflowRunTriggeredFrom
@@ -26,18 +26,16 @@ class WorkflowDailyRunsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(id) AS runs
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
@@ -54,7 +52,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@@ -64,10 +62,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@@ -88,18 +86,16 @@ class WorkflowDailyTerminalsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
@@ -116,7 +112,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@@ -126,10 +122,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@@ -150,18 +146,18 @@ class WorkflowDailyTokenCostStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
SUM(workflow_runs.total_tokens) AS token_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
sql_query = """
SELECT
date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
SUM(workflow_runs.total_tokens) as token_count
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
@@ -178,7 +174,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@@ -188,10 +184,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@@ -217,31 +213,27 @@ class WorkflowAverageAppInteractionStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
AVG(sub.interactions) AS interactions,
sub.date
FROM
(
SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
c.created_by,
COUNT(c.id) AS interactions
FROM
workflow_runs c
WHERE
c.app_id = :app_id
AND c.triggered_from = :triggered_from
{{start}}
{{end}}
GROUP BY
date, c.created_by
) sub
GROUP BY
sub.date"""
sql_query = """
SELECT
AVG(sub.interactions) as interactions,
sub.date
FROM
(SELECT
date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
c.created_by,
COUNT(c.id) AS interactions
FROM workflow_runs c
WHERE c.app_id = :app_id
AND c.triggered_from = :triggered_from
{{start}}
{{end}}
GROUP BY date, c.created_by) sub
GROUP BY sub.date
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
@@ -270,7 +262,7 @@ GROUP BY
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
arg_dict["end"] = end_datetime_utc
else:
sql_query = sql_query.replace("{{end}}", "")

View File

@@ -8,7 +8,7 @@ from constants.languages import supported_language
from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.helper import StrLen, email, timezone
from libs.helper import email, str_len, timezone
from libs.password import hash_password, valid_password
from models.account import AccountStatus
from services.account_service import RegisterService
@@ -37,7 +37,7 @@ class ActivateApi(Resource):
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"

View File

@@ -71,7 +71,7 @@ class OAuthCallback(Resource):
account = _generate_account(provider, user_info)
# Check account status
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
return {"error": "Account is banned or closed."}, 403
if account.status == AccountStatus.PENDING.value:
@@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
if not account:
# Create account
account_name = user_info.name or "Dify"
account_name = user_info.name if user_info.name else "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
)

View File

@@ -18,7 +18,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.retrieval.retrival_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@@ -399,7 +399,7 @@ class DatasetIndexingEstimateApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -550,7 +550,12 @@ class DatasetApiBaseUrlApi(Resource):
@login_required
@account_initialization_required
def get(self):
return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
return {
"api_base_url": (
dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
)
+ "/v1"
}
class DatasetRetrievalSettingApi(Resource):

View File

@@ -302,8 +302,6 @@ class DatasetInitApi(Resource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -311,8 +309,6 @@ class DatasetInitApi(Resource):
raise Forbidden()
if args["indexing_technique"] == "high_quality":
if args["embedding_model"] is None or args["embedding_model_provider"] is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
model_manager = ModelManager()
model_manager.get_default_model_instance(
@@ -354,7 +350,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
if document.indexing_status in {"completed", "error"}:
if document.indexing_status in ["completed", "error"]:
raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule
@@ -421,7 +417,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
info_list = []
extract_settings = []
for document in documents:
if document.indexing_status in {"completed", "error"}:
if document.indexing_status in ["completed", "error"]:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
# format document files info
@@ -665,7 +661,7 @@ class DocumentProcessingApi(DocumentResource):
db.session.commit()
elif action == "resume":
if document.indexing_status not in {"paused", "error"}:
if document.indexing_status not in ["paused", "error"]:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None

View File

@@ -18,7 +18,9 @@ class NotSetupError(BaseHTTPException):
class NotInitValidateError(BaseHTTPException):
error_code = "not_init_validated"
description = "Init validation has not been completed yet. Please proceed with the init validation process first."
description = (
"Init validation has not been completed yet. " "Please proceed with the init validation process first."
)
code = 401

View File

@@ -81,15 +81,19 @@ class ChatTextApi(InstalledAppResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") or text_to_speech.get("voice")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)

View File

@@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@@ -100,7 +100,6 @@ class ChatApi(InstalledAppResource):
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
@@ -141,7 +140,7 @@ class ChatStopApi(InstalledAppResource):
def post(self, installed_app, task_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

View File

@@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource):
def get(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource):
def delete(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource):
def post(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource):
def patch(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource):
def patch(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource):
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
"is_pinned": installed_app.is_pinned,
"last_used_at": installed_app.last_used_at,
"editable": current_user.role in {"owner", "admin"},
"editable": current_user.role in ["owner", "admin"],
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
}
for installed_app in installed_apps

View File

@@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@@ -51,7 +51,7 @@ class MessageListApi(InstalledAppResource):
try:
return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
message_id = str(message_id)

View File

@@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource):
"""Retrieve app parameters."""
app_model = installed_app.app
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@@ -4,7 +4,7 @@ from flask import session
from flask_restful import Resource, reqparse
from configs import dify_config
from libs.helper import StrLen
from libs.helper import str_len
from models.model import DifySetup
from services.account_service import TenantService
@@ -28,7 +28,7 @@ class InitValidateAPI(Resource):
raise AlreadySetupError()
parser = reqparse.RequestParser()
parser.add_argument("password", type=StrLen(30), required=True, location="json")
parser.add_argument("password", type=str_len(30), required=True, location="json")
input_password = parser.parse_args()["password"]
if input_password != os.environ.get("INIT_PASSWORD"):

View File

@@ -4,7 +4,7 @@ from flask import request
from flask_restful import Resource, reqparse
from configs import dify_config
from libs.helper import StrLen, email, get_remote_ip
from libs.helper import email, get_remote_ip, str_len
from libs.password import valid_password
from models.model import DifySetup
from services.account_service import RegisterService, TenantService
@@ -40,7 +40,7 @@ class SetupApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("name", type=StrLen(30), required=True, location="json")
parser.add_argument("name", type=str_len(30), required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
args = parser.parse_args()

View File

@@ -13,7 +13,7 @@ from services.tag_service import TagService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 50:
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 50 characters.")
return name

View File

@@ -218,7 +218,7 @@ api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-provider
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
api.add_resource(
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/" "<string:icon_type>/<string:lang>"
)
api.add_resource(

View File

@@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):
return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
args["provider_name"] or "",
args["provider_name"] if args["provider_name"] else "",
args["tool_name"],
args["credentials"],
args["parameters"],

View File

@@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource):
raise TooManyFilesError()
extension = file.filename.split(".")[-1]
if extension.lower() not in {"svg", "png"}:
if extension.lower() not in ["svg", "png"]:
raise UnsupportedFileTypeError()
try:

View File

@@ -64,8 +64,7 @@ def cloud_edition_billing_resource_check(resource: str):
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The api of file upload is used in the multiple places,
# so we need to check the source of the request from datasets
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
source = request.args.get("source")
if source == "datasets":
abort(403, "The number of documents has reached the limit of your subscription.")

View File

@@ -42,7 +42,7 @@ class AppParameterApi(Resource):
@marshal_with(parameters_fields)
def get(self, app_model: App):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@@ -79,15 +79,19 @@ class TextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") or text_to_speech.get("voice")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception:
voice = None
response = AudioService.transcript_tts(

View File

@@ -96,7 +96,7 @@ class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@@ -144,7 +144,7 @@ class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

View File

@@ -18,7 +18,7 @@ class ConversationApi(Resource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@@ -52,7 +52,7 @@ class ConversationDetailApi(Resource):
@marshal_with(simple_conversation_fields)
def delete(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@@ -69,7 +69,7 @@ class ConversationRenameApi(Resource):
@marshal_with(simple_conversation_fields)
def post(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@@ -54,7 +54,6 @@ class MessageListApi(Resource):
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
@@ -77,7 +76,7 @@ class MessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@@ -118,7 +117,7 @@ class MessageSuggestedApi(Resource):
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
try:

View File

@@ -1,7 +1,6 @@
import logging
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError
from controllers.service_api import api
@@ -23,12 +22,10 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs import helper
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from services.app_generate_service import AppGenerateService
from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__)
@@ -116,30 +113,6 @@ class WorkflowTaskStopApi(Resource):
return {"result": "success"}
class WorkflowAppLogApi(Resource):
@validate_app_token
@marshal_with(workflow_app_log_pagination_fields)
def get(self, app_model: App):
"""
Get workflow app logs
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
app_model=app_model, args=args
)
return workflow_app_log_pagination
api.add_resource(WorkflowRunApi, "/workflows/run")
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_id>")
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
api.add_resource(WorkflowAppLogApi, "/workflows/logs")

View File

@@ -37,7 +37,7 @@ class SegmentApi(DatasetApiResource):
if not document:
raise NotFound("Document not found.")
if document.indexing_status != "completed":
raise NotFound("Document is not completed.")
raise NotFound("Document is already completed.")
if not document.enabled:
raise NotFound("Document is disabled.")
# check embedding model setting
@@ -67,7 +67,7 @@ class SegmentApi(DatasetApiResource):
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
else:
return {"error": "Segments is required"}, 400
return {"error": "Segemtns is required"}, 400
def get(self, tenant_id, dataset_id, document_id):
"""Create single segment."""

View File

@@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource):
@marshal_with(parameters_fields)
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@@ -78,15 +78,19 @@ class TextApi(WebApiResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") or text_to_speech.get("voice")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception:
voice = None

View File

@@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource):
def post(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@@ -96,7 +96,6 @@ class ChatApi(WebApiResource):
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
args = parser.parse_args()
@@ -137,7 +136,7 @@ class ChatApi(WebApiResource):
class ChatStopApi(WebApiResource):
def post(self, app_model, end_user, task_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

View File

@@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource):
class ConversationApi(WebApiResource):
def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource):
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource):
class ConversationPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource):
class ConversationUnPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@@ -57,7 +57,6 @@ class MessageListApi(WebApiResource):
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
@@ -79,7 +78,7 @@ class MessageListApi(WebApiResource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@@ -90,7 +89,7 @@ class MessageListApi(WebApiResource):
try:
return MessageService.pagination_by_first_id(
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -161,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource):
class MessageSuggestedQuestionApi(WebApiResource):
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotCompletionAppError()
message_id = str(message_id)

View File

@@ -80,8 +80,7 @@ def _validate_web_sso_token(decoded, system_features, app_code):
if not source or source != "sso":
raise WebSSOAuthRequiredError()
# Check if SSO is not enforced for web, and if the token source is SSO,
# raise an error and redirect to normal passport login
# Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login
if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
source = decoded.get("token_source")
if source and source == "sso":

View File

@@ -1 +1 @@
import core.moderation.base
import core.moderation.base

View File

@@ -1,7 +1,6 @@
import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone
from typing import Optional, Union, cast
@@ -32,7 +31,6 @@ from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.entities.tool_entities import (
ToolParameter,
ToolRuntimeVariablePool,
@@ -47,25 +45,22 @@ from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner):
def __init__(
self,
tenant_id: str,
application_generate_entity: AgentChatAppGenerateEntity,
conversation: Conversation,
app_config: AgentChatAppConfig,
model_config: ModelConfigWithCredentialsEntity,
config: AgentEntity,
queue_manager: AppQueueManager,
message: Message,
user_id: str,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None,
) -> None:
def __init__(self, tenant_id: str,
application_generate_entity: AgentChatAppGenerateEntity,
conversation: Conversation,
app_config: AgentChatAppConfig,
model_config: ModelConfigWithCredentialsEntity,
config: AgentEntity,
queue_manager: AppQueueManager,
message: Message,
user_id: str,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None
) -> None:
"""
Agent runner
:param tenant_id: tenant id
@@ -93,7 +88,9 @@ class BaseAgentRunner(AppRunner):
self.message = message
self.user_id = user_id
self.memory = memory
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
self.history_prompt_messages = self.organize_agent_history(
prompt_messages=prompt_messages or []
)
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance
@@ -114,16 +111,12 @@ class BaseAgentRunner(AppRunner):
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
return_resource=app_config.additional_features.show_retrieve_source,
invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback,
hit_callback=hit_callback
)
# get how many agent thoughts have been created
self.agent_thought_count = (
db.session.query(MessageAgentThought)
.filter(
MessageAgentThought.message_id == self.message.id,
)
.count()
)
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
MessageAgentThought.message_id == self.message.id,
).count()
db.session.close()
# check if model supports stream tool call
@@ -142,26 +135,25 @@ class BaseAgentRunner(AppRunner):
self.query = None
self._current_thoughts: list[PromptMessage] = []
def _repack_app_generate_entity(
self, app_generate_entity: AgentChatAppGenerateEntity
) -> AgentChatAppGenerateEntity:
def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-> AgentChatAppGenerateEntity:
"""
Repack app generate entity
"""
if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
return app_generate_entity
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
"""
convert tool to prompt message tool
convert tool to prompt message tool
"""
tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
agent_tool=tool,
invoke_from=self.application_generate_entity.invoke_from,
invoke_from=self.application_generate_entity.invoke_from
)
tool_entity.load_variables(self.variables_pool)
@@ -172,7 +164,7 @@ class BaseAgentRunner(AppRunner):
"type": "object",
"properties": {},
"required": [],
},
}
)
parameters = tool_entity.get_all_runtime_parameters()
@@ -185,19 +177,19 @@ class BaseAgentRunner(AppRunner):
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
message_tool.parameters["properties"][parameter.name] = {
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or "",
"description": parameter.llm_description or '',
}
if len(enum) > 0:
message_tool.parameters["properties"][parameter.name]["enum"] = enum
message_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required:
message_tool.parameters["required"].append(parameter.name)
message_tool.parameters['required'].append(parameter.name)
return message_tool, tool_entity
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
"""
convert dataset retriever tool to prompt message tool
@@ -209,24 +201,24 @@ class BaseAgentRunner(AppRunner):
"type": "object",
"properties": {},
"required": [],
},
}
)
for parameter in tool.get_runtime_parameters():
parameter_type = "string"
prompt_tool.parameters["properties"][parameter.name] = {
parameter_type = 'string'
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or "",
"description": parameter.llm_description or '',
}
if parameter.required:
if parameter.name not in prompt_tool.parameters["required"]:
prompt_tool.parameters["required"].append(parameter.name)
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool
def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
"""
Init tools
"""
@@ -269,51 +261,51 @@ class BaseAgentRunner(AppRunner):
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
prompt_tool.parameters["properties"][parameter.name] = {
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or "",
"description": parameter.llm_description or '',
}
if len(enum) > 0:
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required:
if parameter.name not in prompt_tool.parameters["required"]:
prompt_tool.parameters["required"].append(parameter.name)
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool
def create_agent_thought(
self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought:
def create_agent_thought(self, message_id: str, message: str,
tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought:
"""
Create agent thought
"""
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
thought="",
thought='',
tool=tool_name,
tool_labels_str="{}",
tool_meta_str="{}",
tool_labels_str='{}',
tool_meta_str='{}',
tool_input=tool_input,
message=message,
message_token=0,
message_unit_price=0,
message_price_unit=0,
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
message_files=json.dumps(messages_ids) if messages_ids else '',
answer='',
observation='',
answer_token=0,
answer_unit_price=0,
answer_price_unit=0,
tokens=0,
total_price=0,
position=self.agent_thought_count + 1,
currency="USD",
currency='USD',
latency=0,
created_by_role="account",
created_by_role='account',
created_by=self.user_id,
)
@@ -326,22 +318,22 @@ class BaseAgentRunner(AppRunner):
return thought
def save_agent_thought(
self,
agent_thought: MessageAgentThought,
tool_name: str,
tool_input: Union[str, dict],
thought: str,
observation: Union[str, dict],
tool_invoke_meta: Union[str, dict],
answer: str,
messages_ids: list[str],
llm_usage: LLMUsage = None,
) -> MessageAgentThought:
def save_agent_thought(self,
agent_thought: MessageAgentThought,
tool_name: str,
tool_input: Union[str, dict],
thought: str,
observation: Union[str, dict],
tool_invoke_meta: Union[str, dict],
answer: str,
messages_ids: list[str],
llm_usage: LLMUsage = None) -> MessageAgentThought:
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
agent_thought = db.session.query(MessageAgentThought).filter(
MessageAgentThought.id == agent_thought.id
).first()
if thought is not None:
agent_thought.thought = thought
@@ -364,7 +356,7 @@ class BaseAgentRunner(AppRunner):
observation = json.dumps(observation, ensure_ascii=False)
except Exception as e:
observation = json.dumps(observation)
agent_thought.observation = observation
if answer is not None:
@@ -372,7 +364,7 @@ class BaseAgentRunner(AppRunner):
if messages_ids is not None and len(messages_ids) > 0:
agent_thought.message_files = json.dumps(messages_ids)
if llm_usage:
agent_thought.message_token = llm_usage.prompt_tokens
agent_thought.message_price_unit = llm_usage.prompt_price_unit
@@ -385,7 +377,7 @@ class BaseAgentRunner(AppRunner):
# check if tool labels is not empty
labels = agent_thought.tool_labels or {}
tools = agent_thought.tool.split(";") if agent_thought.tool else []
tools = agent_thought.tool.split(';') if agent_thought.tool else []
for tool in tools:
if not tool:
continue
@@ -394,7 +386,7 @@ class BaseAgentRunner(AppRunner):
if tool_label:
labels[tool] = tool_label.to_dict()
else:
labels[tool] = {"en_US": tool, "zh_Hans": tool}
labels[tool] = {'en_US': tool, 'zh_Hans': tool}
agent_thought.tool_labels_str = json.dumps(labels)
@@ -409,18 +401,14 @@ class BaseAgentRunner(AppRunner):
db.session.commit()
db.session.close()
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
"""
convert tool variables to db variables
"""
db_variables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
)
.first()
)
db_variables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
).first()
db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
@@ -437,16 +425,9 @@ class BaseAgentRunner(AppRunner):
if isinstance(prompt_message, SystemPromptMessage):
result.append(prompt_message)
messages: list[Message] = (
db.session.query(Message)
.filter(
Message.conversation_id == self.message.conversation_id,
)
.order_by(Message.created_at.desc())
.all()
)
messages = list(reversed(extract_thread_messages(messages)))
messages: list[Message] = db.session.query(Message).filter(
Message.conversation_id == self.message.conversation_id,
).order_by(Message.created_at.asc()).all()
for message in messages:
if message.id == self.message.id:
@@ -458,13 +439,13 @@ class BaseAgentRunner(AppRunner):
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(";")
tools = tools.split(';')
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception as e:
tool_inputs = {tool: {} for tool in tools}
tool_inputs = { tool: {} for tool in tools }
try:
tool_responses = json.loads(agent_thought.observation)
except Exception as e:
@@ -473,33 +454,27 @@ class BaseAgentRunner(AppRunner):
for tool in tools:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool,
arguments=json.dumps(tool_inputs.get(tool, {})),
),
)
)
tool_call_response.append(
ToolPromptMessage(
content=tool_responses.get(tool, agent_thought.observation),
tool_calls.append(AssistantPromptMessage.ToolCall(
id=tool_call_id,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool,
tool_call_id=tool_call_id,
arguments=json.dumps(tool_inputs.get(tool, {})),
)
)
))
tool_call_response.append(ToolPromptMessage(
content=tool_responses.get(tool, agent_thought.observation),
name=tool,
tool_call_id=tool_call_id,
))
result.extend(
[
AssistantPromptMessage(
content=agent_thought.thought,
tool_calls=tool_calls,
),
*tool_call_response,
]
)
result.extend([
AssistantPromptMessage(
content=agent_thought.thought,
tool_calls=tool_calls,
),
*tool_call_response
])
if not tools:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
@@ -521,7 +496,10 @@ class BaseAgentRunner(AppRunner):
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
file_objs = message_file_parser.transform_message_files(
files,
file_extra_config
)
else:
file_objs = []

View File

@@ -25,19 +25,17 @@ from models.model import Message
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
_ignore_observation_providers = ['wenxin']
_historic_prompt_messages: list[PromptMessage] = None
_agent_scratchpad: list[AgentScratchpadUnit] = None
_instruction: str = None
_query: str = None
_prompt_messages_tools: list[PromptMessage] = None
def run(
self,
message: Message,
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
def run(self, message: Message,
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
"""
Run Cot agent application
"""
@@ -48,16 +46,17 @@ class CotAgentRunner(BaseAgentRunner, ABC):
trace_manager = app_generate_entity.trace_manager
# check model mode
if "Observation" not in app_generate_entity.model_conf.stop:
if 'Observation' not in app_generate_entity.model_conf.stop:
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
app_generate_entity.model_conf.stop.append("Observation")
app_generate_entity.model_conf.stop.append('Observation')
app_config = self.app_config
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
self._instruction = self._fill_in_inputs_from_external_data_tools(
instruction, inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@@ -66,14 +65,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
function_call_state = True
llm_usage = {"usage": None}
final_answer = ""
llm_usage = {
'usage': None
}
final_answer = ''
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
llm_usage = final_llm_usage_dict["usage"]
llm_usage = final_llm_usage_dict['usage']
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price
@@ -93,13 +94,17 @@ class CotAgentRunner(BaseAgentRunner, ABC):
message_file_ids = []
agent_thought = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
message_id=message.id,
message='',
tool_name='',
tool_input='',
messages_ids=message_file_ids
)
if iteration_step > 1:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
@@ -120,20 +125,21 @@ class CotAgentRunner(BaseAgentRunner, ABC):
raise ValueError("failed to invoke llm")
usage_dict = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
react_chunks = CotAgentOutputParser.handle_react_stream_output(
chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",
thought="",
action_str="",
observation="",
agent_response='',
thought='',
action_str='',
observation='',
action=None,
)
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action):
@@ -148,51 +154,61 @@ class CotAgentRunner(BaseAgentRunner, ABC):
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
system_fingerprint='',
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=chunk
),
usage=None
)
)
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
scratchpad.thought = scratchpad.thought.strip(
) or 'I am thinking about how to help you'
self._agent_scratchpad.append(scratchpad)
# get llm usage
if "usage" in usage_dict:
increase_usage(llm_usage, usage_dict["usage"])
if 'usage' in usage_dict:
increase_usage(llm_usage, usage_dict['usage'])
else:
usage_dict["usage"] = LLMUsage.empty_usage()
usage_dict['usage'] = LLMUsage.empty_usage()
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else "",
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_name=scratchpad.action.action_name if scratchpad.action else '',
tool_input={
scratchpad.action.action_name: scratchpad.action.action_input
} if scratchpad.action else {},
tool_invoke_meta={},
thought=scratchpad.thought,
observation="",
observation='',
answer=scratchpad.agent_response,
messages_ids=[],
llm_usage=usage_dict["usage"],
llm_usage=usage_dict['usage']
)
if not scratchpad.is_final():
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
if not scratchpad.action:
# failed to extract action, return final answer directly
final_answer = ""
final_answer = ''
else:
if scratchpad.action.action_name.lower() == "final answer":
# action is final answer, return final answer directly
try:
if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(scratchpad.action.action_input)
final_answer = json.dumps(
scratchpad.action.action_input)
elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input
else:
final_answer = f"{scratchpad.action.action_input}"
final_answer = f'{scratchpad.action.action_input}'
except json.JSONDecodeError:
final_answer = f"{scratchpad.action.action_input}"
final_answer = f'{scratchpad.action.action_input}'
else:
function_call_state = True
# action is tool call, invoke tool
@@ -208,18 +224,21 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
tool_input={
scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought,
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
observation={
scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={
scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
llm_usage=usage_dict["usage"],
llm_usage=usage_dict['usage']
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt tool message
for prompt_tool in self._prompt_messages_tools:
@@ -231,45 +250,44 @@ class CotAgentRunner(BaseAgentRunner, ABC):
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
index=0,
message=AssistantPromptMessage(
content=final_answer
),
usage=llm_usage['usage']
),
system_fingerprint="",
system_fingerprint=''
)
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
tool_name="",
tool_name='',
tool_input={},
tool_invoke_meta={},
thought=final_answer,
observation={},
answer=final_answer,
messages_ids=[],
messages_ids=[]
)
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=final_answer
),
PublishFrom.APPLICATION_MANAGER,
)
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER)
def _handle_invoke_action(
self,
action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool],
message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[str, ToolInvokeMeta]:
def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool],
message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None
) -> tuple[str, ToolInvokeMeta]:
"""
handle invoke action
:param action: action
@@ -308,12 +326,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# publish files
for message_file_id, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
self.variables_pool.set_file(
tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file_id
), PublishFrom.APPLICATION_MANAGER)
# add message file ids
message_file_ids.append(message_file_id)
@@ -323,7 +342,10 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
convert dict to action
"""
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
return AgentScratchpadUnit.Action(
action_name=action['action'],
action_input=action['action_input']
)
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
"""
@@ -331,7 +353,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
for key, value in inputs.items():
try:
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
except Exception as e:
continue
@@ -348,14 +370,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
@abstractmethod
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
organize prompt messages
organize prompt messages
"""
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
format assistant message
format assistant message
"""
message = ""
message = ''
for scratchpad in agent_scratchpad:
if scratchpad.is_final():
message += f"Final Answer: {scratchpad.agent_response}"
@@ -368,11 +390,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return message
def _organize_historic_prompt_messages(
self, current_session_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
organize historic prompt messages
organize historic prompt messages
"""
result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = []
@@ -383,8 +403,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if not current_scratchpad:
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
action_str="",
thought=message.content or 'I am thinking about how to help you',
action_str='',
action=None,
observation=None,
)
@@ -393,9 +413,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments),
action_input=json.loads(
message.tool_calls[0].function.arguments)
)
current_scratchpad.action_str = json.dumps(
current_scratchpad.action.to_dict()
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except:
pass
elif isinstance(message, ToolPromptMessage):
@@ -403,19 +426,23 @@ class CotAgentRunner(BaseAgentRunner, ABC):
current_scratchpad.observation = message.content
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
result.append(AssistantPromptMessage(
content=self._format_assistant_message(scratchpads)
))
scratchpads = []
current_scratchpad = None
result.append(message)
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
result.append(AssistantPromptMessage(
content=self._format_assistant_message(scratchpads)
))
historic_prompts = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=result,
memory=self.memory,
memory=self.memory
).get_prompt()
return historic_prompts

View File

@@ -19,15 +19,14 @@ class CotChatAgentRunner(CotAgentRunner):
prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
system_prompt = first_prompt \
.replace("{{instruction}}", self._instruction) \
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
return SystemPromptMessage(content=system_prompt)
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize user query
"""
@@ -44,7 +43,7 @@ class CotChatAgentRunner(CotAgentRunner):
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize
Organize
"""
# organize system prompt
system_message = self._organize_system_prompt()
@@ -54,7 +53,7 @@ class CotChatAgentRunner(CotAgentRunner):
if not agent_scratchpad:
assistant_messages = []
else:
assistant_message = AssistantPromptMessage(content="")
assistant_message = AssistantPromptMessage(content='')
for unit in agent_scratchpad:
if unit.is_final():
assistant_message.content += f"Final Answer: {unit.agent_response}"
@@ -72,15 +71,18 @@ class CotChatAgentRunner(CotAgentRunner):
if assistant_messages:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages(
[system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
)
historic_messages = self._organize_historic_prompt_messages([
system_message,
*query_messages,
*assistant_messages,
UserPromptMessage(content='continue')
])
messages = [
system_message,
*historic_messages,
*query_messages,
*assistant_messages,
UserPromptMessage(content="continue"),
UserPromptMessage(content='continue')
]
else:
# organize historic prompt messages

View File

@@ -13,12 +13,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
return system_prompt
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
@@ -48,7 +46,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
assistant_prompt = ""
assistant_prompt = ''
for unit in agent_scratchpad:
if unit.is_final():
assistant_prompt += f"Final Answer: {unit.agent_response}"
@@ -63,10 +61,9 @@ class CotCompletionAgentRunner(CotAgentRunner):
query_prompt = f"Question: {self._query}"
# join all messages
prompt = (
system_prompt.replace("{{historic_messages}}", historic_prompt)
.replace("{{agent_scratchpad}}", assistant_prompt)
prompt = system_prompt \
.replace("{{historic_messages}}", historic_prompt) \
.replace("{{agent_scratchpad}}", assistant_prompt) \
.replace("{{query}}", query_prompt)
)
return [UserPromptMessage(content=prompt)]
return [UserPromptMessage(content=prompt)]

View File

@@ -8,7 +8,6 @@ class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: Literal["builtin", "api", "workflow"]
provider_id: str
tool_name: str
@@ -19,7 +18,6 @@ class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
"""
first_prompt: str
next_iteration: str
@@ -33,7 +31,6 @@ class AgentScratchpadUnit(BaseModel):
"""
Action Entity.
"""
action_name: str
action_input: Union[dict, str]
@@ -42,8 +39,8 @@ class AgentScratchpadUnit(BaseModel):
Convert to dictionary.
"""
return {
"action": self.action_name,
"action_input": self.action_input,
'action': self.action_name,
'action_input': self.action_input,
}
agent_response: Optional[str] = None
@@ -57,10 +54,10 @@ class AgentScratchpadUnit(BaseModel):
Check if the scratchpad unit is final.
"""
return self.action is None or (
"final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
'final' in self.action.action_name.lower() and
'answer' in self.action.action_name.lower()
)
class AgentEntity(BaseModel):
"""
Agent Entity.
@@ -70,9 +67,8 @@ class AgentEntity(BaseModel):
"""
Agent Strategy.
"""
CHAIN_OF_THOUGHT = "chain-of-thought"
FUNCTION_CALLING = "function-calling"
CHAIN_OF_THOUGHT = 'chain-of-thought'
FUNCTION_CALLING = 'function-calling'
provider: str
model: str

View File

@@ -24,9 +24,11 @@ from models.model import Message
logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
def run(self,
message: Message, query: str, **kwargs: Any
) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
"""
@@ -43,17 +45,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# continue to run until there is not any tool call
function_call_state = True
llm_usage = {"usage": None}
final_answer = ""
llm_usage = {
'usage': None
}
final_answer = ''
# get tracing instance
trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
llm_usage = final_llm_usage_dict["usage"]
llm_usage = final_llm_usage_dict['usage']
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price
@@ -71,7 +75,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
message_file_ids = []
agent_thought = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
message_id=message.id,
message='',
tool_name='',
tool_input='',
messages_ids=message_file_ids
)
# recalc llm max tokens
@@ -91,11 +99,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
# save full response
response = ""
response = ''
# save tool call names and inputs
tool_call_names = ""
tool_call_inputs = ""
tool_call_names = ''
tool_call_inputs = ''
current_llm_usage = None
@@ -103,22 +111,24 @@ class FunctionCallAgentRunner(BaseAgentRunner):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
is_first_chunk = False
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
@@ -138,14 +148,16 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_blocking_tool_calls(result))
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if result.usage:
increase_usage(llm_usage, result.usage)
@@ -159,12 +171,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
response += result.message.content
if not result.message.content:
result.message.content = ""
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
result.message.content = ''
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
@@ -173,29 +185,32 @@ class FunctionCallAgentRunner(BaseAgentRunner):
index=0,
message=result.message,
usage=result.usage,
),
)
)
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
assistant_message = AssistantPromptMessage(
content='',
tool_calls=[]
)
if tool_calls:
assistant_message.tool_calls = [
assistant_message.tool_calls=[
AssistantPromptMessage.ToolCall(
id=tool_call[0],
type="function",
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
),
)
for tool_call in tool_calls
name=tool_call[1],
arguments=json.dumps(tool_call[2], ensure_ascii=False)
)
) for tool_call in tool_calls
]
else:
assistant_message.content = response
self._current_thoughts.append(assistant_message)
# save thought
self.save_agent_thought(
agent_thought=agent_thought,
agent_thought=agent_thought,
tool_name=tool_call_names,
tool_input=tool_call_inputs,
thought=response,
@@ -203,13 +218,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
observation=None,
answer=response,
messages_ids=[],
llm_usage=current_llm_usage,
llm_usage=current_llm_usage
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
final_answer += response + "\n"
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
final_answer += response + '\n'
# call tools
tool_responses = []
@@ -220,7 +235,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
}
else:
# invoke tool
@@ -240,49 +255,50 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file_id
), PublishFrom.APPLICATION_MANAGER)
# add message file ids
message_file_ids.append(message_file_id)
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": tool_invoke_response,
"meta": tool_invoke_meta.to_dict(),
"meta": tool_invoke_meta.to_dict()
}
tool_responses.append(tool_response)
if tool_response["tool_response"] is not None:
if tool_response['tool_response'] is not None:
self._current_thoughts.append(
ToolPromptMessage(
content=tool_response["tool_response"],
content=tool_response['tool_response'],
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
)
if len(tool_responses) > 0:
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
agent_thought=agent_thought,
tool_name=None,
tool_input=None,
thought=None,
thought=None,
tool_invoke_meta={
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
tool_response['tool_call_name']: tool_response['meta']
for tool_response in tool_responses
},
observation={
tool_response["tool_call_name"]: tool_response["tool_response"]
tool_response['tool_call_name']: tool_response['tool_response']
for tool_response in tool_responses
},
answer=None,
messages_ids=message_file_ids,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
messages_ids=message_file_ids
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt tool
for prompt_tool in prompt_messages_tools:
@@ -292,18 +308,15 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=final_answer
),
PublishFrom.APPLICATION_MANAGER,
)
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER)
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
"""
@@ -312,7 +325,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if llm_result_chunk.delta.message.tool_calls:
return True
return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
@@ -321,9 +334,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return True
return False
def extract_tool_calls(
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract tool calls from llm result chunk
@@ -333,19 +344,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = []
for prompt_message in llm_result_chunk.delta.message.tool_calls:
args = {}
if prompt_message.function.arguments != "":
if prompt_message.function.arguments != '':
args = json.loads(prompt_message.function.arguments)
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
args,
))
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result
@@ -356,22 +365,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
args = {}
if prompt_message.function.arguments != "":
if prompt_message.function.arguments != '':
args = json.loads(prompt_message.function.arguments)
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
args,
))
return tool_calls
def _init_system_message(
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Initialize system message
"""
@@ -379,13 +384,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return [
SystemPromptMessage(content=prompt_template),
]
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize user query
"""
@@ -399,7 +404,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
@@ -410,21 +415,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
prompt_message.content = '\n'.join([
content.data if content.type == PromptMessageContentType.TEXT else
'[image]' if content.type == PromptMessageContentType.IMAGE else
'[file]'
for content in prompt_message.content
])
return prompt_messages
def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query, [])
@@ -432,10 +433,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory,
memory=self.memory
).get_prompt()
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
prompt_messages = [
*self.history_prompt_messages,
*query_prompt_messages,
*self._current_thoughts
]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

View File

@@ -9,9 +9,8 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(json_str):
try:
action = json.loads(json_str)
@@ -23,7 +22,7 @@ class CotAgentOutputParser:
action = action[0]
for key, value in action.items():
if "input" in key.lower():
if 'input' in key.lower():
action_input = value
else:
action_name = value
@@ -34,37 +33,37 @@ class CotAgentOutputParser:
action_input=action_input,
)
else:
return json_str or ""
return json_str or ''
except:
return json_str or ""
return json_str or ''
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
if not code_blocks:
return
for block in code_blocks:
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
yield parse_action(json_text)
code_block_cache = ""
code_block_cache = ''
code_block_delimiter_count = 0
in_code_block = False
json_cache = ""
json_cache = ''
json_quote_count = 0
in_json = False
got_json = False
action_cache = ""
action_str = "action:"
action_cache = ''
action_str = 'action:'
action_idx = 0
thought_cache = ""
thought_str = "thought:"
thought_cache = ''
thought_str = 'thought:'
thought_idx = 0
for response in llm_response:
if response.delta.usage:
usage_dict["usage"] = response.delta.usage
usage_dict['usage'] = response.delta.usage
response = response.delta.message.content
if not isinstance(response, str):
continue
@@ -73,24 +72,24 @@ class CotAgentOutputParser:
index = 0
while index < len(response):
steps = 1
delta = response[index : index + steps]
last_character = response[index - 1] if index > 0 else ""
delta = response[index:index+steps]
last_character = response[index-1] if index > 0 else ''
if delta == "`":
if delta == '`':
code_block_cache += delta
code_block_delimiter_count += 1
else:
if not in_code_block:
if code_block_delimiter_count > 0:
yield code_block_cache
code_block_cache = ""
code_block_cache = ''
else:
code_block_cache += delta
code_block_delimiter_count = 0
if not in_code_block and not in_json:
if delta.lower() == action_str[action_idx] and action_idx == 0:
if last_character not in {"\n", " ", ""}:
if last_character not in ['\n', ' ', '']:
index += steps
yield delta
continue
@@ -98,7 +97,7 @@ class CotAgentOutputParser:
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ""
action_cache = ''
action_idx = 0
index += steps
continue
@@ -106,18 +105,18 @@ class CotAgentOutputParser:
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ""
action_cache = ''
action_idx = 0
index += steps
continue
else:
if action_cache:
yield action_cache
action_cache = ""
action_cache = ''
action_idx = 0
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
if last_character not in {"\n", " ", ""}:
if last_character not in ['\n', ' ', '']:
index += steps
yield delta
continue
@@ -125,7 +124,7 @@ class CotAgentOutputParser:
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ""
thought_cache = ''
thought_idx = 0
index += steps
continue
@@ -133,31 +132,31 @@ class CotAgentOutputParser:
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ""
thought_cache = ''
thought_idx = 0
index += steps
continue
else:
if thought_cache:
yield thought_cache
thought_cache = ""
thought_cache = ''
thought_idx = 0
if code_block_delimiter_count == 3:
if in_code_block:
yield from extra_json_from_code_block(code_block_cache)
code_block_cache = ""
code_block_cache = ''
in_code_block = not in_code_block
code_block_delimiter_count = 0
if not in_code_block:
# handle single json
if delta == "{":
if delta == '{':
json_quote_count += 1
in_json = True
json_cache += delta
elif delta == "}":
elif delta == '}':
json_cache += delta
if json_quote_count > 0:
json_quote_count -= 1
@@ -173,12 +172,12 @@ class CotAgentOutputParser:
if got_json:
got_json = False
yield parse_action(json_cache)
json_cache = ""
json_cache = ''
json_quote_count = 0
in_json = False
if not in_code_block and not in_json:
yield delta.replace("`", "")
yield delta.replace('`', '')
index += steps
@@ -187,3 +186,4 @@ class CotAgentOutputParser:
if json_cache:
yield parse_action(json_cache)

View File

@@ -41,8 +41,7 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
{{historic_messages}}
Question: {{query}}
{{agent_scratchpad}}
Thought:""" # noqa: E501
Thought:"""
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
Thought:"""
@@ -87,20 +86,19 @@ Action:
```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
""" # noqa: E501
"""
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
REACT_PROMPT_TEMPLATES = {
"english": {
"chat": {
"prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
"agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
},
"completion": {
"prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
"agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
'english': {
'chat': {
'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
},
'completion': {
'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
}
}
}
}

View File

@@ -26,24 +26,34 @@ class BaseAppConfigManager:
config_dict = dict(config_dict.items())
additional_features = AppAdditionalFeatures()
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
config=config_dict
)
additional_features.file_upload = FileUploadConfigManager.convert(
config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
config=config_dict,
is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
)
additional_features.opening_statement, additional_features.suggested_questions = (
OpeningStatementConfigManager.convert(config=config_dict)
)
additional_features.opening_statement, additional_features.suggested_questions = \
OpeningStatementConfigManager.convert(
config=config_dict
)
additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
config=config_dict
)
additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)
additional_features.more_like_this = MoreLikeThisConfigManager.convert(
config=config_dict
)
additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)
additional_features.speech_to_text = SpeechToTextConfigManager.convert(
config=config_dict
)
additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)
additional_features.text_to_speech = TextToSpeechConfigManager.convert(
config=config_dict
)
return additional_features

View File

@@ -7,24 +7,25 @@ from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
if not sensitive_word_avoidance_dict:
return None
if sensitive_word_avoidance_dict.get("enabled"):
if sensitive_word_avoidance_dict.get('enabled'):
return SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config"),
type=sensitive_word_avoidance_dict.get('type'),
config=sensitive_word_avoidance_dict.get('config'),
)
else:
return None
@classmethod
def validate_and_set_defaults(
cls, tenant_id, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
-> tuple[dict, list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {"enabled": False}
config["sensitive_word_avoidance"] = {
"enabled": False
}
if not isinstance(config["sensitive_word_avoidance"], dict):
raise ValueError("sensitive_word_avoidance must be of dict type")
@@ -40,6 +41,10 @@ class SensitiveWordAvoidanceConfigManager:
typ = config["sensitive_word_avoidance"]["type"]
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
ModerationFactory.validate_config(
name=typ,
tenant_id=tenant_id,
config=sensitive_word_avoidance_config
)
return config, ["sensitive_word_avoidance"]

View File

@@ -12,70 +12,67 @@ class AgentConfigManager:
:param config: model config args
"""
if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
agent_dict = config.get("agent_mode", {})
agent_strategy = agent_dict.get("strategy", "cot")
if 'agent_mode' in config and config['agent_mode'] \
and 'enabled' in config['agent_mode']:
if agent_strategy == "function_call":
agent_dict = config.get('agent_mode', {})
agent_strategy = agent_dict.get('strategy', 'cot')
if agent_strategy == 'function_call':
strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy in {"cot", "react"}:
elif agent_strategy == 'cot' or agent_strategy == 'react':
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
# old configs, try to detect default strategy
if config["model"]["provider"] == "openai":
if config['model']['provider'] == 'openai':
strategy = AgentEntity.Strategy.FUNCTION_CALLING
else:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
agent_tools = []
for tool in agent_dict.get("tools", []):
for tool in agent_dict.get('tools', []):
keys = tool.keys()
if len(keys) >= 4:
if "enabled" not in tool or not tool["enabled"]:
continue
agent_tool_properties = {
"provider_type": tool["provider_type"],
"provider_id": tool["provider_id"],
"tool_name": tool["tool_name"],
"tool_parameters": tool.get("tool_parameters", {}),
'provider_type': tool['provider_type'],
'provider_id': tool['provider_id'],
'tool_name': tool['tool_name'],
'tool_parameters': tool.get('tool_parameters', {})
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
"react_router",
"router",
}:
agent_prompt = agent_dict.get("prompt", None) or {}
if 'strategy' in config['agent_mode'] and \
config['agent_mode']['strategy'] not in ['react_router', 'router']:
agent_prompt = agent_dict.get('prompt', None) or {}
# check model mode
model_mode = config.get("model", {}).get("mode", "completion")
if model_mode == "completion":
model_mode = config.get('model', {}).get('mode', 'completion')
if model_mode == 'completion':
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get(
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
),
next_iteration=agent_prompt.get(
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
),
first_prompt=agent_prompt.get('first_prompt',
REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['completion'][
'agent_scratchpad']),
)
else:
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get(
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
),
next_iteration=agent_prompt.get(
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
),
first_prompt=agent_prompt.get('first_prompt',
REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
)
return AgentEntity(
provider=config["model"]["provider"],
model=config["model"]["name"],
provider=config['model']['provider'],
model=config['model']['name'],
strategy=strategy,
prompt=agent_prompt_entity,
tools=agent_tools,
max_iteration=agent_dict.get("max_iteration", 5),
max_iteration=agent_dict.get('max_iteration', 5)
)
return None

View File

@@ -15,38 +15,39 @@ class DatasetConfigManager:
:param config: model config args
"""
dataset_ids = []
if "datasets" in config.get("dataset_configs", {}):
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
if 'datasets' in config.get('dataset_configs', {}):
datasets = config.get('dataset_configs', {}).get('datasets', {
'strategy': 'router',
'datasets': []
})
for dataset in datasets.get("datasets", []):
for dataset in datasets.get('datasets', []):
keys = list(dataset.keys())
if len(keys) == 0 or keys[0] != "dataset":
if len(keys) == 0 or keys[0] != 'dataset':
continue
dataset = dataset["dataset"]
dataset = dataset['dataset']
if "enabled" not in dataset or not dataset["enabled"]:
if 'enabled' not in dataset or not dataset['enabled']:
continue
dataset_id = dataset.get("id", None)
dataset_id = dataset.get('id', None)
if dataset_id:
dataset_ids.append(dataset_id)
if (
"agent_mode" in config
and config["agent_mode"]
and "enabled" in config["agent_mode"]
and config["agent_mode"]["enabled"]
):
agent_dict = config.get("agent_mode", {})
if 'agent_mode' in config and config['agent_mode'] \
and 'enabled' in config['agent_mode'] \
and config['agent_mode']['enabled']:
for tool in agent_dict.get("tools", []):
agent_dict = config.get('agent_mode', {})
for tool in agent_dict.get('tools', []):
keys = tool.keys()
if len(keys) == 1:
# old standard
key = list(tool.keys())[0]
if key != "dataset":
if key != 'dataset':
continue
tool_item = tool[key]
@@ -54,28 +55,30 @@ class DatasetConfigManager:
if "enabled" not in tool_item or not tool_item["enabled"]:
continue
dataset_id = tool_item["id"]
dataset_id = tool_item['id']
dataset_ids.append(dataset_id)
if len(dataset_ids) == 0:
return None
# dataset configs
if "dataset_configs" in config and config.get("dataset_configs"):
dataset_configs = config.get("dataset_configs")
if 'dataset_configs' in config and config.get('dataset_configs'):
dataset_configs = config.get('dataset_configs')
else:
dataset_configs = {"retrieval_model": "multiple"}
query_variable = config.get("dataset_query_variable")
dataset_configs = {
'retrieval_model': 'multiple'
}
query_variable = config.get('dataset_query_variable')
if dataset_configs["retrieval_model"] == "single":
if dataset_configs['retrieval_model'] == 'single':
return DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"]
),
),
dataset_configs['retrieval_model']
)
)
)
else:
return DatasetEntity(
@@ -83,15 +86,15 @@ class DatasetConfigManager:
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"]
dataset_configs['retrieval_model']
),
top_k=dataset_configs.get("top_k", 4),
score_threshold=dataset_configs.get("score_threshold"),
reranking_model=dataset_configs.get("reranking_model"),
weights=dataset_configs.get("weights"),
reranking_enabled=dataset_configs.get("reranking_enabled", True),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
),
top_k=dataset_configs.get('top_k', 4),
score_threshold=dataset_configs.get('score_threshold'),
reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights'),
reranking_enabled=dataset_configs.get('reranking_enabled', True),
rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
)
)
@classmethod
@@ -108,10 +111,13 @@ class DatasetConfigManager:
# dataset_configs
if not config.get("dataset_configs"):
config["dataset_configs"] = {"retrieval_model": "single"}
config["dataset_configs"] = {'retrieval_model': 'single'}
if not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
config["dataset_configs"]["datasets"] = {
"strategy": "router",
"datasets": []
}
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
@@ -119,9 +125,8 @@ class DatasetConfigManager:
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
"datasets", {}
).get("datasets")
need_manual_query_datasets = (config.get("dataset_configs")
and config["dataset_configs"].get("datasets", {}).get("datasets"))
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion
@@ -143,7 +148,10 @@ class DatasetConfigManager:
"""
# Extract dataset config for legacy compatibility
if not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []}
config["agent_mode"] = {
"enabled": False,
"tools": []
}
if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
@@ -167,7 +175,7 @@ class DatasetConfigManager:
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
has_datasets = False
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]:
for tool in config["agent_mode"]["tools"]:
key = list(tool.keys())[0]
if key == "dataset":
@@ -180,7 +188,7 @@ class DatasetConfigManager:
if not isinstance(tool_item["enabled"], bool):
raise ValueError("enabled in agent_mode.tools must be of boolean type")
if "id" not in tool_item:
if 'id' not in tool_item:
raise ValueError("id is required in dataset")
try:

View File

@@ -11,7 +11,9 @@ from core.provider_manager import ProviderManager
class ModelConfigConverter:
@classmethod
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
def convert(cls, app_config: EasyUIBasedAppConfig,
skip_check: bool = False) \
-> ModelConfigWithCredentialsEntity:
"""
Convert app model config dict to entity.
:param app_config: app config
@@ -23,7 +25,9 @@ class ModelConfigConverter:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
tenant_id=app_config.tenant_id,
provider=model_config.provider,
model_type=ModelType.LLM
)
provider_name = provider_model_bundle.configuration.provider.provider
@@ -34,7 +38,8 @@ class ModelConfigConverter:
# check model credentials
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM, model=model_config.model
model_type=ModelType.LLM,
model=model_config.model
)
if model_credentials is None:
@@ -46,7 +51,8 @@ class ModelConfigConverter:
if not skip_check:
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model, model_type=ModelType.LLM
model=model_config.model,
model_type=ModelType.LLM
)
if provider_model is None:
@@ -63,18 +69,24 @@ class ModelConfigConverter:
# model config
completion_params = model_config.parameters
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
# get model mode
model_mode = model_config.mode
if not model_mode:
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
mode_enum = model_type_instance.get_model_mode(
model=model_config.model,
credentials=model_credentials
)
model_mode = mode_enum.value
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
model_schema = model_type_instance.get_model_schema(
model_config.model,
model_credentials
)
if not skip_check and not model_schema:
raise ValueError(f"Model {model_name} not exist.")

View File

@@ -13,23 +13,23 @@ class ModelConfigManager:
:param config: model config args
"""
# model config
model_config = config.get("model")
model_config = config.get('model')
if not model_config:
raise ValueError("model is required")
completion_params = model_config.get("completion_params")
completion_params = model_config.get('completion_params')
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
# get model mode
model_mode = model_config.get("mode")
model_mode = model_config.get('mode')
return ModelConfigEntity(
provider=config["model"]["provider"],
model=config["model"]["name"],
provider=config['model']['provider'],
model=config['model']['name'],
mode=model_mode,
parameters=completion_params,
stop=stop,
@@ -43,7 +43,7 @@ class ModelConfigManager:
:param tenant_id: tenant id
:param config: app model config args
"""
if "model" not in config:
if 'model' not in config:
raise ValueError("model is required")
if not isinstance(config["model"], dict):
@@ -52,16 +52,17 @@ class ModelConfigManager:
# model.provider
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
# model.name
if "name" not in config["model"]:
if 'name' not in config["model"]:
raise ValueError("model.name is required")
provider_manager = ProviderManager()
models = provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"], model_type=ModelType.LLM
provider=config["model"]["provider"],
model_type=ModelType.LLM
)
if not models:
@@ -79,12 +80,12 @@ class ModelConfigManager:
# model.mode
if model_mode:
config["model"]["mode"] = model_mode
config['model']["mode"] = model_mode
else:
config["model"]["mode"] = "completion"
config['model']["mode"] = "completion"
# model.completion_params
if "completion_params" not in config["model"]:
if 'completion_params' not in config["model"]:
raise ValueError("model.completion_params is required")
config["model"]["completion_params"] = cls.validate_model_completion_params(
@@ -100,7 +101,7 @@ class ModelConfigManager:
raise ValueError("model.completion_params must be of object type")
# stop
if "stop" not in cp:
if 'stop' not in cp:
cp["stop"] = []
elif not isinstance(cp["stop"], list):
raise ValueError("stop in model.completion_params must be of list type")

View File

@@ -14,33 +14,39 @@ class PromptTemplateConfigManager:
if not config.get("prompt_type"):
raise ValueError("prompt_type is required")
prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
simple_prompt_template = config.get("pre_prompt", "")
return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
return PromptTemplateEntity(
prompt_type=prompt_type,
simple_prompt_template=simple_prompt_template
)
else:
advanced_chat_prompt_template = None
chat_prompt_config = config.get("chat_prompt_config", {})
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append(
{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
)
chat_prompt_messages.append({
"text": message["text"],
"role": PromptMessageRole.value_of(message["role"])
})
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
messages=chat_prompt_messages
)
advanced_completion_prompt_template = None
completion_prompt_config = config.get("completion_prompt_config", {})
if completion_prompt_config:
completion_prompt_template_params = {
"prompt": completion_prompt_config["prompt"]["text"],
'prompt': completion_prompt_config['prompt']['text'],
}
if "conversation_histories_role" in completion_prompt_config:
completion_prompt_template_params["role_prefix"] = {
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
if 'conversation_histories_role' in completion_prompt_config:
completion_prompt_template_params['role_prefix'] = {
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
}
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
@@ -50,7 +56,7 @@ class PromptTemplateConfigManager:
return PromptTemplateEntity(
prompt_type=prompt_type,
advanced_chat_prompt_template=advanced_chat_prompt_template,
advanced_completion_prompt_template=advanced_completion_prompt_template,
advanced_completion_prompt_template=advanced_completion_prompt_template
)
@classmethod
@@ -66,7 +72,7 @@ class PromptTemplateConfigManager:
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
if config["prompt_type"] not in prompt_type_vals:
if config['prompt_type'] not in prompt_type_vals:
raise ValueError(f"prompt_type must be in {prompt_type_vals}")
# chat_prompt_config
@@ -83,28 +89,27 @@ class PromptTemplateConfigManager:
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
raise ValueError(
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
)
if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value:
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
raise ValueError("chat_prompt_config or completion_prompt_config is required "
"when prompt_type is advanced")
model_mode_vals = [mode.value for mode in ModelMode]
if config["model"]["mode"] not in model_mode_vals:
if config['model']["mode"] not in model_mode_vals:
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
if not user_prefix:
config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
if not assistant_prefix:
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
if config["model"]["mode"] == ModelMode.CHAT.value:
prompt_list = config["chat_prompt_config"]["prompt"]
if config['model']["mode"] == ModelMode.CHAT.value:
prompt_list = config['chat_prompt_config']['prompt']
if len(prompt_list) > 10:
raise ValueError("prompt messages must be less than 10")

View File

@@ -16,49 +16,51 @@ class BasicVariablesConfigManager:
variable_entities = []
# old external_data_tools
external_data_tools = config.get("external_data_tools", [])
external_data_tools = config.get('external_data_tools', [])
for external_data_tool in external_data_tools:
if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
continue
external_data_variables.append(
ExternalDataVariableEntity(
variable=external_data_tool["variable"],
type=external_data_tool["type"],
config=external_data_tool["config"],
variable=external_data_tool['variable'],
type=external_data_tool['type'],
config=external_data_tool['config']
)
)
# variables and external_data_tools
for variables in config.get("user_input_form", []):
for variables in config.get('user_input_form', []):
variable_type = list(variables.keys())[0]
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
variable = variables[variable_type]
if "config" not in variable:
if 'config' not in variable:
continue
external_data_variables.append(
ExternalDataVariableEntity(
variable=variable["variable"], type=variable["type"], config=variable["config"]
variable=variable['variable'],
type=variable['type'],
config=variable['config']
)
)
elif variable_type in {
elif variable_type in [
VariableEntityType.TEXT_INPUT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.SELECT,
}:
]:
variable = variables[variable_type]
variable_entities.append(
VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description"),
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options"),
default=variable.get("default"),
variable=variable.get('variable'),
description=variable.get('description'),
label=variable.get('label'),
required=variable.get('required', False),
max_length=variable.get('max_length'),
options=variable.get('options'),
default=variable.get('default'),
)
)
@@ -97,17 +99,17 @@ class BasicVariablesConfigManager:
variables = []
for item in config["user_input_form"]:
key = list(item.keys())[0]
if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
form_item = item[key]
if "label" not in form_item:
if 'label' not in form_item:
raise ValueError("label is required in user_input_form")
if not isinstance(form_item["label"], str):
raise ValueError("label in user_input_form must be of string type")
if "variable" not in form_item:
if 'variable' not in form_item:
raise ValueError("variable is required in user_input_form")
if not isinstance(form_item["variable"], str):
@@ -115,24 +117,26 @@ class BasicVariablesConfigManager:
pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
if pattern.match(form_item["variable"]) is None:
raise ValueError("variable in user_input_form must be a string, and cannot start with a number")
raise ValueError("variable in user_input_form must be a string, "
"and cannot start with a number")
variables.append(form_item["variable"])
if "required" not in form_item or not form_item["required"]:
if 'required' not in form_item or not form_item["required"]:
form_item["required"] = False
if not isinstance(form_item["required"], bool):
raise ValueError("required in user_input_form must be of boolean type")
if key == "select":
if "options" not in form_item or not form_item["options"]:
if 'options' not in form_item or not form_item["options"]:
form_item["options"] = []
if not isinstance(form_item["options"], list):
raise ValueError("options in user_input_form must be a list of strings")
if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
if "default" in form_item and form_item['default'] \
and form_item["default"] not in form_item["options"]:
raise ValueError("default value in user_input_form must be in the options list")
return config, ["user_input_form"]
@@ -164,6 +168,10 @@ class BasicVariablesConfigManager:
typ = tool["type"]
config = tool["config"]
ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config)
ExternalDataToolFactory.validate_config(
name=typ,
tenant_id=tenant_id,
config=config
)
return config, ["external_data_tools"]

Some files were not shown because too many files have changed in this diff Show More