mirror of
https://github.com/langgenius/dify.git
synced 2026-01-09 07:44:12 +00:00
Compare commits
311 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce5b19d011 | ||
|
|
f82a64d149 | ||
|
|
f49b1afd6c | ||
|
|
796c5626a7 | ||
|
|
e54c9cd401 | ||
|
|
f8951d7f57 | ||
|
|
6454e1d644 | ||
|
|
e184c8cb42 | ||
|
|
fdd211e399 | ||
|
|
7001e21e7d | ||
|
|
82d0732c12 | ||
|
|
53cd125780 | ||
|
|
3c91f9b5ab | ||
|
|
f073dca22a | ||
|
|
8b1e35d7dc | ||
|
|
b75d8ca621 | ||
|
|
9beefd7d5a | ||
|
|
88145efa97 | ||
|
|
bdc13f9238 | ||
|
|
ce58f0607b | ||
|
|
bbc0d330a9 | ||
|
|
60e7e17c86 | ||
|
|
237bb8514e | ||
|
|
bd26c933d2 | ||
|
|
b6b58da2d2 | ||
|
|
40c646cf7a | ||
|
|
3231a8c51c | ||
|
|
4170d6a491 | ||
|
|
0b50c525cf | ||
|
|
8ba38e8e74 | ||
|
|
b163545771 | ||
|
|
c0b82f8e58 | ||
|
|
b75ff5fa03 | ||
|
|
9440d7fe88 | ||
|
|
24809fce07 | ||
|
|
9819ad347f | ||
|
|
8fe83750b7 | ||
|
|
1809f05904 | ||
|
|
0ac250a035 | ||
|
|
405a00bb2c | ||
|
|
3a3ca8e6a9 | ||
|
|
27e678480e | ||
|
|
7052565380 | ||
|
|
31070ffbca | ||
|
|
7f3dec7bee | ||
|
|
b1e0db4944 | ||
|
|
c439952a41 | ||
|
|
2f28afebb6 | ||
|
|
fa7ba30ba3 | ||
|
|
1cf5f510ed | ||
|
|
526c874caa | ||
|
|
f88f744097 | ||
|
|
95733796f0 | ||
|
|
552f319b9d | ||
|
|
38e5952417 | ||
|
|
7f891939f1 | ||
|
|
69a5ce1e31 | ||
|
|
534802b761 | ||
|
|
5c258e212c | ||
|
|
6a6133c102 | ||
|
|
3c1825187a | ||
|
|
8523b34be7 | ||
|
|
65cfd4360a | ||
|
|
bbf5f42c87 | ||
|
|
3631e53ff0 | ||
|
|
f322d9bddb | ||
|
|
05ce7b9d5e | ||
|
|
72ddedfc5c | ||
|
|
36686d7425 | ||
|
|
34387ec0f1 | ||
|
|
83a6b0c626 | ||
|
|
76da66fb7e | ||
|
|
607f9eda35 | ||
|
|
f25cec265d | ||
|
|
8e66b96221 | ||
|
|
b5c1bb346c | ||
|
|
e94b323e6c | ||
|
|
bc65ee10c0 | ||
|
|
2001483659 | ||
|
|
444aba55dd | ||
|
|
3f640b1037 | ||
|
|
b07084711c | ||
|
|
fa8ab2134f | ||
|
|
1a677da792 | ||
|
|
b6d61a818e | ||
|
|
8495ffaa45 | ||
|
|
dbd1d79770 | ||
|
|
1910178199 | ||
|
|
839a6a2c8a | ||
|
|
a769edbc89 | ||
|
|
57ffecd0e5 | ||
|
|
801d135390 | ||
|
|
0428f44113 | ||
|
|
7beff3fd5a | ||
|
|
88a095e40e | ||
|
|
dd961985f0 | ||
|
|
d44b05a9e5 | ||
|
|
5bd3b02be6 | ||
|
|
3cf5c1853d | ||
|
|
a4d86496e1 | ||
|
|
90bdc85f8c | ||
|
|
0828873b52 | ||
|
|
816b707a16 | ||
|
|
c9257ab4bf | ||
|
|
69ce3b3d33 | ||
|
|
c4caa7c401 | ||
|
|
dc93a292c3 | ||
|
|
174ee1b646 | ||
|
|
9b1c4f47fb | ||
|
|
582ba45c00 | ||
|
|
f1cbd55007 | ||
|
|
3a34370422 | ||
|
|
29ab244de6 | ||
|
|
920b2c2b40 | ||
|
|
ac96d192a6 | ||
|
|
07fbeb6cf0 | ||
|
|
fc64cdee64 | ||
|
|
0c0e96c55f | ||
|
|
5b953c1ef2 | ||
|
|
562ca45e07 | ||
|
|
6bbd53512e | ||
|
|
e352a8ed1b | ||
|
|
e55225e2bc | ||
|
|
3e63abd335 | ||
|
|
0620fa3094 | ||
|
|
d93288f711 | ||
|
|
ca69af7b97 | ||
|
|
952e13fef8 | ||
|
|
4be3087642 | ||
|
|
49da8a23a8 | ||
|
|
3ad943a9eb | ||
|
|
3082093293 | ||
|
|
b03bbab5ad | ||
|
|
9574730050 | ||
|
|
91ea6fe4ee | ||
|
|
769be13189 | ||
|
|
e42175241e | ||
|
|
12257b438b | ||
|
|
9ecc736c30 | ||
|
|
6c4e6bf1d6 | ||
|
|
97fe817186 | ||
|
|
52b12ed7eb | ||
|
|
d8ab4474b4 | ||
|
|
1ecbd95adf | ||
|
|
cad6e6624f | ||
|
|
3505cbe05c | ||
|
|
e15359e589 | ||
|
|
edb86f5f5a | ||
|
|
adf2651d1f | ||
|
|
5031d64e28 | ||
|
|
ae3ad59b16 | ||
|
|
e6cd7b0467 | ||
|
|
97e9f52331 | ||
|
|
25957d917a | ||
|
|
20b932da97 | ||
|
|
207080babc | ||
|
|
48bacd01cc | ||
|
|
297d0f1f30 | ||
|
|
eedbe1b770 | ||
|
|
5ff6b1da07 | ||
|
|
8b49e0ee2a | ||
|
|
e031ec9359 | ||
|
|
1bd1cd6938 | ||
|
|
81c5a21b8d | ||
|
|
61e4bbabaf | ||
|
|
4cf475680d | ||
|
|
ca4aa340f6 | ||
|
|
767d8a4b05 | ||
|
|
0b8dcaba8f | ||
|
|
af6a318aae | ||
|
|
c6e2900be7 | ||
|
|
963d9b6032 | ||
|
|
b2ee738bb1 | ||
|
|
c8ca3ff404 | ||
|
|
5d8fa2c7af | ||
|
|
58df5e5376 | ||
|
|
348ad1a624 | ||
|
|
73e17d5aa8 | ||
|
|
300d9892a5 | ||
|
|
e47b5b43b8 | ||
|
|
21c9d9e200 | ||
|
|
4f6916c4d8 | ||
|
|
8633957726 | ||
|
|
0850c953b3 | ||
|
|
23e95fd7ab | ||
|
|
e1045f01c6 | ||
|
|
e6d22fc3a0 | ||
|
|
9232244920 | ||
|
|
476eb90a90 | ||
|
|
063191889d | ||
|
|
589099a005 | ||
|
|
a0ec7de058 | ||
|
|
14a19a3da9 | ||
|
|
1b04382a9b | ||
|
|
71e5828d41 | ||
|
|
65a02f7d32 | ||
|
|
acf9174bef | ||
|
|
243ca5b1e2 | ||
|
|
f6059c377c | ||
|
|
41328bde97 | ||
|
|
3242cf5384 | ||
|
|
d8de2017b5 | ||
|
|
843280f82b | ||
|
|
42344795cd | ||
|
|
517f6d1a26 | ||
|
|
70992609d4 | ||
|
|
bf736bc55d | ||
|
|
d4cfd3e7ac | ||
|
|
c2d47cd2e1 | ||
|
|
e1a9e0ac29 | ||
|
|
5e145c1c22 | ||
|
|
714ff3c663 | ||
|
|
f5c08070d9 | ||
|
|
392995ca46 | ||
|
|
805ed84f61 | ||
|
|
5010706d8b | ||
|
|
6278ff0f30 | ||
|
|
56c25bfb78 | ||
|
|
b814f0b7e3 | ||
|
|
65bec16fb3 | ||
|
|
556d1d0390 | ||
|
|
1ebf740908 | ||
|
|
51d359268e | ||
|
|
3f0c515355 | ||
|
|
f95839c785 | ||
|
|
5a004ae429 | ||
|
|
04fb610fe7 | ||
|
|
a667d04e53 | ||
|
|
a8f23ed712 | ||
|
|
ecf947258a | ||
|
|
a58612718e | ||
|
|
cd078a6264 | ||
|
|
9f637ead38 | ||
|
|
b521aafd26 | ||
|
|
a84e15b8cc | ||
|
|
0c330fc020 | ||
|
|
cfbb7bec58 | ||
|
|
3b357f51a6 | ||
|
|
09acf215f0 | ||
|
|
07dd8b94ed | ||
|
|
ef308fd121 | ||
|
|
fce64d760b | ||
|
|
f0c9bb7c91 | ||
|
|
d8672796b0 | ||
|
|
5929e84036 | ||
|
|
83063532a0 | ||
|
|
07279558a5 | ||
|
|
2166473852 | ||
|
|
44397e3062 | ||
|
|
883a0a0e6a | ||
|
|
b5ed81b349 | ||
|
|
625b0afa52 | ||
|
|
2660fbaa20 | ||
|
|
9e37702d24 | ||
|
|
bc11c6a7f2 | ||
|
|
10e9766fd3 | ||
|
|
6d24a2cb87 | ||
|
|
0a4dfaeaf9 | ||
|
|
c0a4fd145c | ||
|
|
70f16e1a0b | ||
|
|
cb27571e9f | ||
|
|
0518da5819 | ||
|
|
d2797abdb4 | ||
|
|
bf3ee660e0 | ||
|
|
68406b9906 | ||
|
|
6f7fd6613a | ||
|
|
6d5b386394 | ||
|
|
1ea18a2922 | ||
|
|
f8f4b961a1 | ||
|
|
57565db531 | ||
|
|
d844420c07 | ||
|
|
34634bddf1 | ||
|
|
c97b7f6748 | ||
|
|
76cc19f525 | ||
|
|
5baaebb3fd | ||
|
|
9d072920da | ||
|
|
965ca36525 | ||
|
|
b4988ce20c | ||
|
|
d3d617239f | ||
|
|
6c3b34a61d | ||
|
|
d76d1adb59 | ||
|
|
cadc6b171e | ||
|
|
fdae2a20ae | ||
|
|
45701a81e9 | ||
|
|
409e0c8e1c | ||
|
|
7076d41b29 | ||
|
|
5a6cb69951 | ||
|
|
11a75ee78a | ||
|
|
b9b692d71d | ||
|
|
d8f8afcbd0 | ||
|
|
8cb62ef31a | ||
|
|
bb5d5fc683 | ||
|
|
2fc0dcc10a | ||
|
|
9fd55157d6 | ||
|
|
6c384dba71 | ||
|
|
9730297381 | ||
|
|
99e80a8ed0 | ||
|
|
26fef2d481 | ||
|
|
c9e65f4221 | ||
|
|
20bd33fada | ||
|
|
bd0af2e921 | ||
|
|
4ab66299d4 | ||
|
|
42227f93c0 | ||
|
|
89fcf4ea7c | ||
|
|
3322710dac | ||
|
|
404bf11d8c | ||
|
|
60a2ecbd17 | ||
|
|
828822243a | ||
|
|
2068ae215e | ||
|
|
d4262ecceb | ||
|
|
8be7d8a635 |
4
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
4
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@@ -10,7 +10,9 @@ body:
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/document_issue.yml
vendored
4
.github/ISSUE_TEMPLATE/document_issue.yml
vendored
@@ -10,7 +10,9 @@ body:
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
- label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
4
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
@@ -10,7 +10,9 @@ body:
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/help_wanted.yml
vendored
4
.github/ISSUE_TEMPLATE/help_wanted.yml
vendored
@@ -10,7 +10,9 @@ body:
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/translation_issue.yml
vendored
4
.github/ISSUE_TEMPLATE/translation_issue.yml
vendored
@@ -10,7 +10,9 @@ body:
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: input
|
||||
attributes:
|
||||
|
||||
30
.github/pull_request_template.md
vendored
Normal file
30
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
# Description
|
||||
|
||||
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
|
||||
|
||||
Fixes # (issue)
|
||||
|
||||
## Type of Change
|
||||
|
||||
Please delete options that are not relevant.
|
||||
|
||||
- [ ] Bug fix (non-breaking change which fixes an issue)
|
||||
- [ ] New feature (non-breaking change which adds functionality)
|
||||
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
||||
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
|
||||
|
||||
# How Has This Been Tested?
|
||||
|
||||
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration
|
||||
|
||||
- [ ] TODO
|
||||
|
||||
# Suggested Checklist:
|
||||
|
||||
- [ ] I have performed a self-review of my own code
|
||||
- [ ] I have commented my code, particularly in hard-to-understand areas
|
||||
- [ ] My changes generate no new warnings
|
||||
- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
|
||||
- [ ] `optional` I have made corresponding changes to the documentation
|
||||
- [ ] `optional` I have added tests that prove my fix is effective or that my feature works
|
||||
- [ ] `optional` New and existing unit tests pass locally with my changes
|
||||
33
.github/workflows/style.yml
vendored
33
.github/workflows/style.yml
vendored
@@ -10,14 +10,40 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: ESLint and SuperLinter
|
||||
python-style:
|
||||
name: Python Style
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Python dependencies
|
||||
run: pip install ruff
|
||||
|
||||
- name: Ruff check
|
||||
run: ruff check ./api
|
||||
|
||||
- name: Lint hints
|
||||
if: failure()
|
||||
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
|
||||
|
||||
test:
|
||||
name: ESLint and SuperLinter
|
||||
runs-on: ubuntu-latest
|
||||
needs: python-style
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
@@ -36,11 +62,10 @@ jobs:
|
||||
yarn run lint
|
||||
|
||||
- name: Super-linter
|
||||
uses: super-linter/super-linter/slim@v5
|
||||
uses: super-linter/super-linter/slim@v6
|
||||
env:
|
||||
BASH_SEVERITY: warning
|
||||
DEFAULT_BRANCH: main
|
||||
ERROR_ON_MISSING_EXEC_BIT: true
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
IGNORE_GENERATED_FILES: true
|
||||
IGNORE_GITIGNORED_FILES: true
|
||||
|
||||
34
.github/workflows/tool-test-sdks.yaml
vendored
Normal file
34
.github/workflows/tool-test-sdks.yaml
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
name: Run Unit Test For SDKs
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
jobs:
|
||||
build:
|
||||
name: unit test for Node.js SDK
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
node-version: [16, 18, 20]
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: sdks/nodejs-client
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Use Node.js ${{ matrix.node-version }}
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
cache: ''
|
||||
cache-dependency-path: 'yarn.lock'
|
||||
|
||||
- name: Install Dependencies
|
||||
run: yarn install
|
||||
|
||||
- name: Test
|
||||
run: yarn test
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -145,6 +145,9 @@ docker/volumes/db/data/*
|
||||
docker/volumes/redis/data/*
|
||||
docker/volumes/weaviate/*
|
||||
docker/volumes/qdrant/*
|
||||
docker/volumes/etcd/*
|
||||
docker/volumes/minio/*
|
||||
docker/volumes/milvus/*
|
||||
|
||||
sdks/python-client/build
|
||||
sdks/python-client/dist
|
||||
|
||||
22
LICENSE
22
LICENSE
@@ -1,24 +1,26 @@
|
||||
# Dify Open Source License
|
||||
# Open Source License
|
||||
|
||||
The Dify project is licensed under the Apache License 2.0, with the following additional conditions:
|
||||
Dify is licensed under the Apache License 2.0, with the following additional conditions:
|
||||
|
||||
1. Dify is permitted to be used for commercialization, such as using Dify as a "backend-as-a-service" for your other applications, or delivering it to enterprises as an application development platform. However, when the following conditions are met, you must contact the producer to obtain a commercial license:
|
||||
1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
|
||||
|
||||
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify.AI source code to operate a multi-tenant SaaS service that is similar to the Dify.AI service edition.
|
||||
b. LOGO and copyright information: In the process of using Dify, you may not remove or modify the LOGO or copyright information in the Dify console.
|
||||
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
|
||||
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
|
||||
|
||||
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.
|
||||
|
||||
Please contact business@dify.ai by email to inquire about licensing matters.
|
||||
|
||||
2. As a contributor, you should agree that your contributed code:
|
||||
2. As a contributor, you should agree that:
|
||||
|
||||
a. The producer can adjust the open-source agreement to be more strict or relaxed.
|
||||
b. Can be used for commercial purposes, such as Dify's cloud business.
|
||||
a. The producer can adjust the open-source agreement to be more strict or relaxed as deemed necessary.
|
||||
b. Your contributed code may be used for commercial purposes, including but not limited to its cloud business operations.
|
||||
|
||||
Apart from this, all other rights and restrictions follow the Apache License 2.0. If you need more detailed information, you can refer to the full version of Apache License 2.0.
|
||||
Apart from the specific conditions mentioned above, all other rights and restrictions follow the Apache License 2.0. Detailed information about the Apache License 2.0 can be found at http://www.apache.org/licenses/LICENSE-2.0.
|
||||
|
||||
The interactive design of this product is protected by appearance patent.
|
||||
|
||||
© 2023 LangGenius, Inc.
|
||||
© 2024 LangGenius, Inc.
|
||||
|
||||
|
||||
----------
|
||||
|
||||
@@ -81,11 +81,17 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
||||
# Model Configuration
|
||||
MULTIMODAL_SEND_IMAGE_FORMAT=base64
|
||||
|
||||
# Mail configuration, support: resend
|
||||
# Mail configuration, support: resend, smtp
|
||||
MAIL_TYPE=
|
||||
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
|
||||
RESEND_API_KEY=
|
||||
RESEND_API_URL=https://api.resend.com
|
||||
# smtp configuration
|
||||
SMTP_SERVER=smtp.gmail.com
|
||||
SMTP_PORT=587
|
||||
SMTP_USERNAME=123
|
||||
SMTP_PASSWORD=abc
|
||||
SMTP_USE_TLS=false
|
||||
|
||||
# Sentry configuration
|
||||
SENTRY_DSN=
|
||||
@@ -120,4 +126,9 @@ HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
|
||||
HOSTED_ANTHROPIC_PAID_ENABLED=false
|
||||
|
||||
ETL_TYPE=dify
|
||||
UNSTRUCTURED_API_URL=
|
||||
UNSTRUCTURED_API_URL=
|
||||
|
||||
SSRF_PROXY_HTTP_URL=
|
||||
SSRF_PROXY_HTTPS_URL=
|
||||
|
||||
BATCH_UPLOAD_LIMIT=10
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
1. Start the docker-compose stack
|
||||
|
||||
The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
|
||||
|
||||
|
||||
```bash
|
||||
cd ../docker
|
||||
docker-compose -f docker-compose.middleware.yaml -p dify up -d
|
||||
@@ -15,7 +15,7 @@
|
||||
3. Generate a `SECRET_KEY` in the `.env` file.
|
||||
|
||||
```bash
|
||||
openssl rand -base64 42
|
||||
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
|
||||
```
|
||||
3.5 If you use annaconda, create a new environment and activate it
|
||||
```bash
|
||||
@@ -46,7 +46,7 @@
|
||||
```
|
||||
pip install -r requirements.txt --upgrade --force-reinstall
|
||||
```
|
||||
|
||||
|
||||
6. Start backend:
|
||||
```bash
|
||||
flask run --host 0.0.0.0 --port=5001 --debug
|
||||
|
||||
32
api/app.py
32
api/app.py
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
@@ -19,20 +18,32 @@ import threading
|
||||
import time
|
||||
import warnings
|
||||
|
||||
from commands import register_commands
|
||||
from config import CloudEditionConfig, Config
|
||||
from events import event_handlers
|
||||
from extensions import (ext_celery, ext_code_based_extension, ext_database, ext_hosting_provider, ext_login, ext_mail,
|
||||
ext_migrate, ext_redis, ext_sentry, ext_storage)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
from flask import Flask, Response, request
|
||||
from flask_cors import CORS
|
||||
|
||||
from commands import register_commands
|
||||
from config import CloudEditionConfig, Config
|
||||
from extensions import (
|
||||
ext_celery,
|
||||
ext_code_based_extension,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_hosting_provider,
|
||||
ext_login,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_redis,
|
||||
ext_sentry,
|
||||
ext_storage,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
from libs.passport import PassportService
|
||||
# DO NOT REMOVE BELOW
|
||||
from models import account, dataset, model, source, task, tool, web, tools
|
||||
from services.account_service import AccountService
|
||||
|
||||
# DO NOT REMOVE BELOW
|
||||
from events import event_handlers
|
||||
from models import account, dataset, model, source, task, tool, tools, web
|
||||
# DO NOT REMOVE ABOVE
|
||||
|
||||
|
||||
@@ -86,6 +97,7 @@ def create_app(test_config=None) -> Flask:
|
||||
def initialize_extensions(app):
|
||||
# Since the application instance is now created, pass it to each Flask
|
||||
# extension instance to bind it to the Flask application instance (app)
|
||||
ext_compress.init_app(app)
|
||||
ext_code_based_extension.init()
|
||||
ext_database.init_app(app)
|
||||
ext_migrate.init(app, db)
|
||||
|
||||
887
api/commands.py
887
api/commands.py
@@ -1,33 +1,22 @@
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import math
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import click
|
||||
import qdrant_client
|
||||
from constants.languages import user_input_form_template
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.index import IndexBuilder
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from flask import current_app
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from flask import Flask, current_app
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, password_pattern, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import InvitationCode, Tenant, TenantAccountJoin
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetQuery, Document
|
||||
from models.model import Account, App, AppModelConfig, Message, MessageAnnotation, InstalledApp
|
||||
from models.provider import Provider, ProviderModel, ProviderQuotaType, ProviderType
|
||||
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
|
||||
from tqdm import tqdm
|
||||
from werkzeug.exceptions import NotFound
|
||||
from models.account import Tenant
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, MessageAnnotation
|
||||
from models.provider import Provider, ProviderModel
|
||||
|
||||
|
||||
@click.command('reset-password', help='Reset the account password.')
|
||||
@@ -35,15 +24,22 @@ from werkzeug.exceptions import NotFound
|
||||
@click.option('--new-password', prompt=True, help='the new password.')
|
||||
@click.option('--password-confirm', prompt=True, help='the new password confirm.')
|
||||
def reset_password(email, new_password, password_confirm):
|
||||
"""
|
||||
Reset password of owner account
|
||||
Only available in SELF_HOSTED mode
|
||||
"""
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
|
||||
return
|
||||
|
||||
account = db.session.query(Account). \
|
||||
filter(Account.email == email). \
|
||||
one_or_none()
|
||||
|
||||
if not account:
|
||||
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
|
||||
return
|
||||
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
@@ -69,15 +65,22 @@ def reset_password(email, new_password, password_confirm):
|
||||
@click.option('--new-email', prompt=True, help='the new email.')
|
||||
@click.option('--email-confirm', prompt=True, help='the new email confirm.')
|
||||
def reset_email(email, new_email, email_confirm):
|
||||
"""
|
||||
Replace account email
|
||||
:return:
|
||||
"""
|
||||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
|
||||
return
|
||||
|
||||
account = db.session.query(Account). \
|
||||
filter(Account.email == email). \
|
||||
one_or_none()
|
||||
|
||||
if not account:
|
||||
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
@@ -97,6 +100,11 @@ def reset_email(email, new_email, email_confirm):
|
||||
@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
|
||||
' this operation cannot be rolled back!', fg='red'))
|
||||
def reset_encrypt_key_pair():
|
||||
"""
|
||||
Reset the encrypted key pair of workspace for encrypt LLM credentials.
|
||||
After the reset, all LLM credentials will become invalid, requiring re-entry.
|
||||
Only support SELF_HOSTED mode.
|
||||
"""
|
||||
if current_app.config['EDITION'] != 'SELF_HOSTED':
|
||||
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
|
||||
return
|
||||
@@ -116,657 +124,254 @@ def reset_encrypt_key_pair():
|
||||
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
|
||||
|
||||
|
||||
@click.command('generate-invitation-codes', help='Generate invitation codes.')
|
||||
@click.option('--batch', help='The batch of invitation codes.')
|
||||
@click.option('--count', prompt=True, help='Invitation codes count.')
|
||||
def generate_invitation_codes(batch, count):
|
||||
if not batch:
|
||||
now = datetime.datetime.now()
|
||||
batch = now.strftime('%Y%m%d%H%M%S')
|
||||
|
||||
if not count or int(count) <= 0:
|
||||
click.echo(click.style('sorry. the count must be greater than 0.', fg='red'))
|
||||
return
|
||||
|
||||
count = int(count)
|
||||
|
||||
click.echo('Start generate {} invitation codes for batch {}.'.format(count, batch))
|
||||
|
||||
codes = ''
|
||||
for i in range(count):
|
||||
code = generate_invitation_code()
|
||||
invitation_code = InvitationCode(
|
||||
code=code,
|
||||
batch=batch
|
||||
)
|
||||
db.session.add(invitation_code)
|
||||
click.echo(code)
|
||||
|
||||
codes += code + "\n"
|
||||
db.session.commit()
|
||||
|
||||
filename = 'storage/invitation-codes-{}.txt'.format(batch)
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
f.write(codes)
|
||||
|
||||
click.echo(click.style(
|
||||
'Congratulations! Generated {} invitation codes for batch {} and saved to the file \'{}\''.format(count, batch,
|
||||
filename),
|
||||
fg='green'))
|
||||
@click.command('vdb-migrate', help='migrate vector db.')
|
||||
@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
|
||||
def vdb_migrate(scope: str):
|
||||
if scope in ['knowledge', 'all']:
|
||||
migrate_knowledge_vector_database()
|
||||
if scope in ['annotation', 'all']:
|
||||
migrate_annotation_vector_database()
|
||||
|
||||
|
||||
def generate_invitation_code():
|
||||
code = generate_upper_string()
|
||||
while db.session.query(InvitationCode).filter(InvitationCode.code == code).count() > 0:
|
||||
code = generate_upper_string()
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def generate_upper_string():
|
||||
letters_digits = string.ascii_uppercase + string.digits
|
||||
result = ""
|
||||
for i in range(8):
|
||||
result += random.choice(letters_digits)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
|
||||
def recreate_all_dataset_indexes():
|
||||
click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
|
||||
recreate_count = 0
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
try:
|
||||
click.echo('Recreating dataset index: {}'.format(dataset.id))
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
if index and index._is_origin():
|
||||
index.recreate_dataset(dataset)
|
||||
recreate_count += 1
|
||||
else:
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
|
||||
|
||||
|
||||
@click.command('clean-unused-dataset-indexes', help='Clean unused dataset indexes.')
|
||||
def clean_unused_dataset_indexes():
|
||||
click.echo(click.style('Start clean unused dataset indexes.', fg='green'))
|
||||
clean_days = int(current_app.config.get('CLEAN_DAY_SETTING'))
|
||||
start_at = time.perf_counter()
|
||||
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.created_at < thirty_days_ago) \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
dataset_query = db.session.query(DatasetQuery).filter(
|
||||
DatasetQuery.created_at > thirty_days_ago,
|
||||
DatasetQuery.dataset_id == dataset.id
|
||||
).all()
|
||||
if not dataset_query or len(dataset_query) == 0:
|
||||
documents = db.session.query(Document).filter(
|
||||
Document.dataset_id == dataset.id,
|
||||
Document.indexing_status == 'completed',
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
Document.updated_at > thirty_days_ago
|
||||
).all()
|
||||
if not documents or len(documents) == 0:
|
||||
try:
|
||||
# remove index
|
||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
# delete from vector index
|
||||
if vector_index:
|
||||
if dataset.collection_binding_id:
|
||||
vector_index.delete_by_group_id(dataset.id)
|
||||
else:
|
||||
if dataset.collection_binding_id:
|
||||
vector_index.delete_by_group_id(dataset.id)
|
||||
else:
|
||||
vector_index.delete()
|
||||
kw_index.delete()
|
||||
# update document
|
||||
update_params = {
|
||||
Document.enabled: False
|
||||
}
|
||||
|
||||
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
|
||||
db.session.commit()
|
||||
click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id),
|
||||
fg='green'))
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
end_at = time.perf_counter()
|
||||
click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green'))
|
||||
|
||||
|
||||
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
|
||||
def sync_anthropic_hosted_providers():
|
||||
if not hosted_model_providers.anthropic:
|
||||
click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
|
||||
return
|
||||
|
||||
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
|
||||
count = 0
|
||||
|
||||
new_quota_limit = hosted_model_providers.anthropic.quota_limit
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
providers = db.session.query(Provider).filter(
|
||||
Provider.provider_name == 'anthropic',
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value,
|
||||
Provider.quota_limit != new_quota_limit
|
||||
).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for provider in providers:
|
||||
try:
|
||||
click.echo('Syncing tenant anthropic hosted provider: {}, origin: limit {}, used {}'
|
||||
.format(provider.tenant_id, provider.quota_limit, provider.quota_used))
|
||||
original_quota_limit = provider.quota_limit
|
||||
division = math.ceil(new_quota_limit / 1000)
|
||||
|
||||
provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
|
||||
else original_quota_limit * division
|
||||
provider.quota_used = division * provider.quota_used
|
||||
db.session.commit()
|
||||
|
||||
count += 1
|
||||
except Exception as e:
|
||||
click.echo(click.style(
|
||||
'Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
|
||||
|
||||
|
||||
@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
|
||||
def create_qdrant_indexes():
|
||||
click.echo(click.style('Start create qdrant indexes.', fg='green'))
|
||||
def migrate_annotation_vector_database():
|
||||
"""
|
||||
Migrate annotation datas to target vector database .
|
||||
"""
|
||||
click.echo(click.style('Start migrate annotation data.', fg='green'))
|
||||
create_count = 0
|
||||
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
model_manager = ModelManager()
|
||||
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
if dataset.index_struct_dict:
|
||||
if dataset.index_struct_dict['type'] != 'qdrant':
|
||||
try:
|
||||
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
|
||||
try:
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
embedding_model = model_manager.get_default_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
dataset.embedding_model = embedding_model.model
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
except Exception:
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider_name='openai',
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
||||
model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
|
||||
|
||||
index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if index:
|
||||
index.create_qdrant_dataset(dataset)
|
||||
index_struct = {
|
||||
"type": 'qdrant',
|
||||
"vector_store": {
|
||||
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct)
|
||||
db.session.commit()
|
||||
create_count += 1
|
||||
else:
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
|
||||
|
||||
|
||||
@click.command('update-qdrant-indexes', help='Update qdrant indexes.')
|
||||
def update_qdrant_indexes():
|
||||
click.echo(click.style('Start Update qdrant indexes.', fg='green'))
|
||||
create_count = 0
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
# get apps info
|
||||
apps = db.session.query(App).filter(
|
||||
App.status == 'normal'
|
||||
).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
if dataset.index_struct_dict:
|
||||
if dataset.index_struct_dict['type'] != 'qdrant':
|
||||
try:
|
||||
click.echo('Update dataset qdrant index: {}'.format(dataset.id))
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except Exception:
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider_name='openai',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
||||
model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
|
||||
|
||||
index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if index:
|
||||
index.update_qdrant_dataset(dataset)
|
||||
create_count += 1
|
||||
else:
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
|
||||
|
||||
|
||||
@click.command('normalization-collections', help='restore all collections in one')
|
||||
def normalization_collections():
|
||||
click.echo(click.style('Start normalization collections.', fg='green'))
|
||||
normalization_count = []
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
|
||||
except NotFound:
|
||||
break
|
||||
datasets_result = datasets.items
|
||||
page += 1
|
||||
for i in range(0, len(datasets_result), 5):
|
||||
threads = []
|
||||
sub_datasets = datasets_result[i:i + 5]
|
||||
for dataset in sub_datasets:
|
||||
document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'normalization_count': normalization_count
|
||||
})
|
||||
threads.append(document_format_thread)
|
||||
document_format_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
|
||||
|
||||
|
||||
@click.command('add-qdrant-full-text-index', help='add qdrant full text index')
|
||||
def add_qdrant_full_text_index():
|
||||
click.echo(click.style('Start add full text index.', fg='green'))
|
||||
binds = db.session.query(DatasetCollectionBinding).all()
|
||||
if binds and current_app.config['VECTOR_STORE'] == 'qdrant':
|
||||
qdrant_url = current_app.config['QDRANT_URL']
|
||||
qdrant_api_key = current_app.config['QDRANT_API_KEY']
|
||||
client = qdrant_client.QdrantClient(
|
||||
qdrant_url,
|
||||
api_key=qdrant_api_key, # For Qdrant Cloud, None for local instance
|
||||
)
|
||||
for bind in binds:
|
||||
for app in apps:
|
||||
total_count = total_count + 1
|
||||
click.echo(f'Processing the {total_count} app {app.id}. '
|
||||
+ f'{create_count} created, {skipped_count} skipped.')
|
||||
try:
|
||||
text_index_params = TextIndexParams(
|
||||
type=TextIndexType.TEXT,
|
||||
tokenizer=TokenizerType.MULTILINGUAL,
|
||||
min_token_len=2,
|
||||
max_token_len=20,
|
||||
lowercase=True
|
||||
click.echo('Create app annotation index: {}'.format(app.id))
|
||||
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app.id
|
||||
).first()
|
||||
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo('App annotation setting is disabled: {}'.format(app.id))
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
|
||||
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
|
||||
).first()
|
||||
if not dataset_collection_binding:
|
||||
click.echo('App annotation collection binding is not exist: {}'.format(app.id))
|
||||
continue
|
||||
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
embedding_model_provider=dataset_collection_binding.provider_name,
|
||||
embedding_model=dataset_collection_binding.model_name,
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
)
|
||||
client.create_payload_index(bind.collection_name, 'page_content',
|
||||
field_schema=text_index_params)
|
||||
documents = []
|
||||
if annotations:
|
||||
for annotation in annotations:
|
||||
document = Document(
|
||||
page_content=annotation.question,
|
||||
metadata={
|
||||
"annotation_id": annotation.id,
|
||||
"app_id": app.id,
|
||||
"doc_id": annotation.id
|
||||
}
|
||||
)
|
||||
documents.append(document)
|
||||
|
||||
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
|
||||
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
|
||||
|
||||
try:
|
||||
vector.delete()
|
||||
click.echo(
|
||||
click.style(f'Successfully delete vector index for app: {app.id}.',
|
||||
fg='green'))
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(f'Failed to delete vector index for app {app.id}.',
|
||||
fg='red'))
|
||||
raise e
|
||||
if documents:
|
||||
try:
|
||||
click.echo(click.style(
|
||||
f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
|
||||
fg='green'))
|
||||
vector.create(documents)
|
||||
click.echo(
|
||||
click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
|
||||
raise e
|
||||
click.echo(f'Successfully migrated app annotation {app.id}.')
|
||||
create_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create full text index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
click.echo(
|
||||
click.style(
|
||||
'Congratulations! add collection {} full text index successful.'.format(bind.collection_name),
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
|
||||
fg='green'))
|
||||
|
||||
|
||||
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
|
||||
with flask_app.app_context():
|
||||
def migrate_knowledge_vector_database():
|
||||
"""
|
||||
Migrate vector database datas to target vector database .
|
||||
"""
|
||||
click.echo(click.style('Start migrate vector db.', fg='green'))
|
||||
create_count = 0
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
config = current_app.config
|
||||
vector_type = config.get('VECTOR_STORE')
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
click.echo('restore dataset index: {}'.format(dataset.id))
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
total_count = total_count + 1
|
||||
click.echo(f'Processing the {total_count} dataset {dataset.id}. '
|
||||
+ f'{create_count} created, ${skipped_count} skipped.')
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except Exception:
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider_name='openai',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
||||
model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
|
||||
DatasetCollectionBinding.model_name == embedding_model.name). \
|
||||
order_by(DatasetCollectionBinding.created_at). \
|
||||
first()
|
||||
|
||||
if not dataset_collection_binding:
|
||||
dataset_collection_binding = DatasetCollectionBinding(
|
||||
provider_name=embedding_model.model_provider.provider_name,
|
||||
model_name=embedding_model.name,
|
||||
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
|
||||
)
|
||||
db.session.add(dataset_collection_binding)
|
||||
db.session.commit()
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
|
||||
|
||||
index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if index:
|
||||
# index.delete_by_group_id(dataset.id)
|
||||
index.restore_dataset_in_one(dataset, dataset_collection_binding)
|
||||
else:
|
||||
click.echo('passed.')
|
||||
normalization_count.append(1)
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
|
||||
|
||||
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
|
||||
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
|
||||
def update_app_model_configs(batch_size):
|
||||
pre_prompt_template = '{{default_input}}'
|
||||
|
||||
click.secho("Start migrate old data that the text generator can support paragraph variable.", fg='green')
|
||||
|
||||
total_records = db.session.query(AppModelConfig) \
|
||||
.join(App, App.app_model_config_id == AppModelConfig.id) \
|
||||
.filter(App.mode == 'completion') \
|
||||
.count()
|
||||
|
||||
if total_records == 0:
|
||||
click.secho("No data to migrate.", fg='green')
|
||||
return
|
||||
|
||||
num_batches = (total_records + batch_size - 1) // batch_size
|
||||
|
||||
with tqdm(total=total_records, desc="Migrating Data") as pbar:
|
||||
for i in range(num_batches):
|
||||
offset = i * batch_size
|
||||
limit = min(batch_size, total_records - offset)
|
||||
|
||||
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
|
||||
|
||||
data_batch = db.session.query(AppModelConfig) \
|
||||
.join(App, App.app_model_config_id == AppModelConfig.id) \
|
||||
.filter(App.mode == 'completion') \
|
||||
.order_by(App.created_at) \
|
||||
.offset(offset).limit(limit).all()
|
||||
|
||||
if not data_batch:
|
||||
click.secho("No more data to migrate.", fg='green')
|
||||
break
|
||||
|
||||
try:
|
||||
click.secho(f"Migrating {len(data_batch)} records...", fg='green')
|
||||
for data in data_batch:
|
||||
# click.secho(f"Migrating data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')
|
||||
|
||||
if data.pre_prompt is None:
|
||||
data.pre_prompt = pre_prompt_template
|
||||
else:
|
||||
if pre_prompt_template in data.pre_prompt:
|
||||
continue
|
||||
data.pre_prompt += pre_prompt_template
|
||||
|
||||
app_data = db.session.query(App) \
|
||||
.filter(App.id == data.app_id) \
|
||||
.one()
|
||||
|
||||
account_data = db.session.query(Account) \
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
|
||||
.filter(TenantAccountJoin.role == 'owner') \
|
||||
.filter(TenantAccountJoin.tenant_id == app_data.tenant_id) \
|
||||
.one_or_none()
|
||||
|
||||
if not account_data:
|
||||
click.echo('Create dataset vdb index: {}'.format(dataset.id))
|
||||
if dataset.index_struct_dict:
|
||||
if dataset.index_struct_dict['type'] == vector_type:
|
||||
skipped_count = skipped_count + 1
|
||||
continue
|
||||
|
||||
if data.user_input_form is None or data.user_input_form == 'null':
|
||||
data.user_input_form = json.dumps(user_input_form_template[account_data.interface_language])
|
||||
collection_name = ''
|
||||
if vector_type == "weaviate":
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": 'weaviate',
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == "qdrant":
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
|
||||
one_or_none()
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError('Dataset Collection Bindings is not exist!')
|
||||
else:
|
||||
raw_json_data = json.loads(data.user_input_form)
|
||||
raw_json_data.append(user_input_form_template[account_data.interface_language][0])
|
||||
data.user_input_form = json.dumps(raw_json_data)
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": 'qdrant',
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
|
||||
# click.secho(f"Updated data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')
|
||||
elif vector_type == "milvus":
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": 'milvus',
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
else:
|
||||
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
||||
|
||||
vector = Vector(dataset)
|
||||
click.echo(f"Start to migrate dataset {dataset.id}.")
|
||||
|
||||
try:
|
||||
vector.delete()
|
||||
click.echo(
|
||||
click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
|
||||
fg='green'))
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
|
||||
fg='red'))
|
||||
raise e
|
||||
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == 'completed',
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
segments_count = 0
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
segments_count = segments_count + 1
|
||||
|
||||
if documents:
|
||||
try:
|
||||
click.echo(click.style(
|
||||
f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
|
||||
fg='green'))
|
||||
vector.create(documents)
|
||||
click.echo(
|
||||
click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
|
||||
raise e
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
|
||||
except Exception as e:
|
||||
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
|
||||
fg='red')
|
||||
continue
|
||||
|
||||
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
|
||||
|
||||
pbar.update(len(data_batch))
|
||||
|
||||
|
||||
@click.command('migrate_default_input_to_dataset_query_variable')
|
||||
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
|
||||
def migrate_default_input_to_dataset_query_variable(batch_size):
|
||||
click.secho("Starting...", fg='green')
|
||||
|
||||
total_records = db.session.query(AppModelConfig) \
|
||||
.join(App, App.app_model_config_id == AppModelConfig.id) \
|
||||
.filter(App.mode == 'completion') \
|
||||
.filter(AppModelConfig.dataset_query_variable == None) \
|
||||
.count()
|
||||
|
||||
if total_records == 0:
|
||||
click.secho("No data to migrate.", fg='green')
|
||||
return
|
||||
|
||||
num_batches = (total_records + batch_size - 1) // batch_size
|
||||
|
||||
with tqdm(total=total_records, desc="Migrating Data") as pbar:
|
||||
for i in range(num_batches):
|
||||
offset = i * batch_size
|
||||
limit = min(batch_size, total_records - offset)
|
||||
|
||||
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
|
||||
|
||||
data_batch = db.session.query(AppModelConfig) \
|
||||
.join(App, App.app_model_config_id == AppModelConfig.id) \
|
||||
.filter(App.mode == 'completion') \
|
||||
.filter(AppModelConfig.dataset_query_variable == None) \
|
||||
.order_by(App.created_at) \
|
||||
.offset(offset).limit(limit).all()
|
||||
|
||||
if not data_batch:
|
||||
click.secho("No more data to migrate.", fg='green')
|
||||
break
|
||||
|
||||
try:
|
||||
click.secho(f"Migrating {len(data_batch)} records...", fg='green')
|
||||
for data in data_batch:
|
||||
config = AppModelConfig.to_dict(data)
|
||||
|
||||
tools = config["agent_mode"]["tools"]
|
||||
dataset_exists = "dataset" in str(tools)
|
||||
if not dataset_exists:
|
||||
continue
|
||||
|
||||
user_input_form = config.get("user_input_form", [])
|
||||
for form in user_input_form:
|
||||
paragraph = form.get('paragraph')
|
||||
if paragraph \
|
||||
and paragraph.get('variable') == 'query':
|
||||
data.dataset_query_variable = 'query'
|
||||
break
|
||||
|
||||
if paragraph \
|
||||
and paragraph.get('variable') == 'default_input':
|
||||
data.dataset_query_variable = 'default_input'
|
||||
break
|
||||
|
||||
db.session.commit()
|
||||
|
||||
except Exception as e:
|
||||
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
|
||||
fg='red')
|
||||
continue
|
||||
|
||||
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
|
||||
|
||||
pbar.update(len(data_batch))
|
||||
|
||||
|
||||
@click.command('add-annotation-question-field-value', help='add annotation question value')
|
||||
def add_annotation_question_field_value():
|
||||
click.echo(click.style('Start add annotation question value.', fg='green'))
|
||||
message_annotations = db.session.query(MessageAnnotation).all()
|
||||
message_annotation_deal_count = 0
|
||||
if message_annotations:
|
||||
for message_annotation in message_annotations:
|
||||
try:
|
||||
if message_annotation.message_id and not message_annotation.question:
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_annotation.message_id
|
||||
).first()
|
||||
message_annotation.question = message.query
|
||||
db.session.add(message_annotation)
|
||||
db.session.commit()
|
||||
message_annotation_deal_count += 1
|
||||
click.echo(f'Successfully migrated dataset {dataset.id}.')
|
||||
create_count += 1
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
click.echo(
|
||||
click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
click.echo(
|
||||
click.style(f'Congratulations! add annotation question value successful. Deal count {message_annotation_deal_count}', fg='green'))
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
|
||||
fg='green'))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
app.cli.add_command(generate_invitation_codes)
|
||||
app.cli.add_command(reset_encrypt_key_pair)
|
||||
app.cli.add_command(recreate_all_dataset_indexes)
|
||||
app.cli.add_command(sync_anthropic_hosted_providers)
|
||||
app.cli.add_command(clean_unused_dataset_indexes)
|
||||
app.cli.add_command(create_qdrant_indexes)
|
||||
app.cli.add_command(update_qdrant_indexes)
|
||||
app.cli.add_command(update_app_model_configs)
|
||||
app.cli.add_command(normalization_collections)
|
||||
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
|
||||
app.cli.add_command(add_qdrant_full_text_index)
|
||||
app.cli.add_command(add_annotation_question_field_value)
|
||||
app.cli.add_command(vdb_migrate)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
|
||||
import dotenv
|
||||
@@ -39,18 +38,14 @@ DEFAULTS = {
|
||||
'LOG_LEVEL': 'INFO',
|
||||
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
|
||||
'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
|
||||
'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
|
||||
'HOSTED_OPENAI_PAID_ENABLED': 'False',
|
||||
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
|
||||
'HOSTED_OPENAI_PAID_MIN_QUANTITY': 1,
|
||||
'HOSTED_OPENAI_PAID_MAX_QUANTITY': 1,
|
||||
'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
|
||||
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
|
||||
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
|
||||
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
|
||||
'HOSTED_ANTHROPIC_TRIAL_ENABLED': 'False',
|
||||
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
|
||||
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
|
||||
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 1,
|
||||
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 1,
|
||||
'HOSTED_MODERATION_ENABLED': 'False',
|
||||
'HOSTED_MODERATION_PROVIDERS': '',
|
||||
'CLEAN_DAY_SETTING': 30,
|
||||
@@ -63,6 +58,8 @@ DEFAULTS = {
|
||||
'BILLING_ENABLED': 'False',
|
||||
'CAN_REPLACE_LOGO': 'False',
|
||||
'ETL_TYPE': 'dify',
|
||||
'KEYWORD_STORE': 'jieba',
|
||||
'BATCH_UPLOAD_LIMIT': 20
|
||||
}
|
||||
|
||||
|
||||
@@ -93,7 +90,7 @@ class Config:
|
||||
# ------------------------
|
||||
# General Configurations.
|
||||
# ------------------------
|
||||
self.CURRENT_VERSION = "0.5.1"
|
||||
self.CURRENT_VERSION = "0.5.9"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
@@ -189,7 +186,7 @@ class Config:
|
||||
# Currently, only support: qdrant, milvus, zilliz, weaviate
|
||||
# ------------------------
|
||||
self.VECTOR_STORE = get_env('VECTOR_STORE')
|
||||
|
||||
self.KEYWORD_STORE = get_env('KEYWORD_STORE')
|
||||
# qdrant settings
|
||||
self.QDRANT_URL = get_env('QDRANT_URL')
|
||||
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
|
||||
@@ -215,6 +212,12 @@ class Config:
|
||||
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
|
||||
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
|
||||
self.RESEND_API_URL = get_env('RESEND_API_URL')
|
||||
# SMTP settings
|
||||
self.SMTP_SERVER = get_env('SMTP_SERVER')
|
||||
self.SMTP_PORT = get_env('SMTP_PORT')
|
||||
self.SMTP_USERNAME = get_env('SMTP_USERNAME')
|
||||
self.SMTP_PASSWORD = get_env('SMTP_PASSWORD')
|
||||
self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS')
|
||||
|
||||
# ------------------------
|
||||
# Workpace Configurations.
|
||||
@@ -260,12 +263,10 @@ class Config:
|
||||
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
|
||||
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
|
||||
self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
|
||||
self.HOSTED_OPENAI_TRIAL_MODELS = get_env('HOSTED_OPENAI_TRIAL_MODELS')
|
||||
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
|
||||
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
|
||||
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
|
||||
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
|
||||
self.HOSTED_OPENAI_PAID_MIN_QUANTITY = int(get_env('HOSTED_OPENAI_PAID_MIN_QUANTITY'))
|
||||
self.HOSTED_OPENAI_PAID_MAX_QUANTITY = int(get_env('HOSTED_OPENAI_PAID_MAX_QUANTITY'))
|
||||
self.HOSTED_OPENAI_PAID_MODELS = get_env('HOSTED_OPENAI_PAID_MODELS')
|
||||
|
||||
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
|
||||
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
|
||||
@@ -277,10 +278,6 @@ class Config:
|
||||
self.HOSTED_ANTHROPIC_TRIAL_ENABLED = get_bool_env('HOSTED_ANTHROPIC_TRIAL_ENABLED')
|
||||
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
|
||||
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
|
||||
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
|
||||
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
|
||||
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
|
||||
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
|
||||
|
||||
self.HOSTED_MINIMAX_ENABLED = get_bool_env('HOSTED_MINIMAX_ENABLED')
|
||||
self.HOSTED_SPARK_ENABLED = get_bool_env('HOSTED_SPARK_ENABLED')
|
||||
@@ -294,6 +291,10 @@ class Config:
|
||||
self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
|
||||
self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
|
||||
|
||||
self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
|
||||
|
||||
self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
|
||||
|
||||
|
||||
class CloudEditionConfig(Config):
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
|
||||
import json
|
||||
|
||||
from models.model import AppModelConfig
|
||||
|
||||
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT']
|
||||
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA']
|
||||
|
||||
language_timezone_mapping = {
|
||||
'en-US': 'America/New_York',
|
||||
@@ -15,8 +15,10 @@ language_timezone_mapping = {
|
||||
'ko-KR': 'Asia/Seoul',
|
||||
'ru-RU': 'Europe/Moscow',
|
||||
'it-IT': 'Europe/Rome',
|
||||
'uk-UA': 'Europe/Kyiv',
|
||||
}
|
||||
|
||||
|
||||
def supported_language(lang):
|
||||
if lang in languages:
|
||||
return lang
|
||||
@@ -25,6 +27,7 @@ def supported_language(lang):
|
||||
.format(lang=lang))
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
user_input_form_template = {
|
||||
"en-US": [
|
||||
{
|
||||
@@ -66,6 +69,16 @@ user_input_form_template = {
|
||||
}
|
||||
}
|
||||
],
|
||||
"ua-UK": [
|
||||
{
|
||||
"paragraph": {
|
||||
"label": "Запит",
|
||||
"variable": "default_input",
|
||||
"required": False,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
demo_model_templates = {
|
||||
@@ -144,7 +157,7 @@ demo_model_templates = {
|
||||
'Italian',
|
||||
]
|
||||
}
|
||||
},{
|
||||
}, {
|
||||
"paragraph": {
|
||||
"label": "Query",
|
||||
"variable": "query",
|
||||
@@ -271,7 +284,7 @@ demo_model_templates = {
|
||||
"意大利语",
|
||||
]
|
||||
}
|
||||
},{
|
||||
}, {
|
||||
"paragraph": {
|
||||
"label": "文本内容",
|
||||
"variable": "query",
|
||||
@@ -322,5 +335,130 @@ demo_model_templates = {
|
||||
)
|
||||
}
|
||||
],
|
||||
'uk-UA': [{
|
||||
"name": "Помічник перекладу",
|
||||
"icon": "",
|
||||
"icon_background": "",
|
||||
"description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.",
|
||||
"mode": "completion",
|
||||
"model_config": AppModelConfig(
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo-instruct",
|
||||
configs={
|
||||
"prompt_template": "Будь ласка, перекладіть наступний текст на {{target_language}}:\n",
|
||||
"prompt_variables": [
|
||||
{
|
||||
"key": "target_language",
|
||||
"name": "Цільова мова",
|
||||
"description": "Мова, на яку ви хочете перекласти.",
|
||||
"type": "select",
|
||||
"default": "Ukrainian",
|
||||
"options": [
|
||||
"Chinese",
|
||||
"English",
|
||||
"Japanese",
|
||||
"French",
|
||||
"Russian",
|
||||
"German",
|
||||
"Spanish",
|
||||
"Korean",
|
||||
"Italian",
|
||||
],
|
||||
},
|
||||
],
|
||||
"completion_params": {
|
||||
"max_token": 1000,
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"presence_penalty": 0.1,
|
||||
"frequency_penalty": 0.1,
|
||||
},
|
||||
},
|
||||
opening_statement="",
|
||||
suggested_questions=None,
|
||||
pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"presence_penalty": 0.1,
|
||||
"frequency_penalty": 0.1,
|
||||
},
|
||||
}),
|
||||
user_input_form=json.dumps([
|
||||
{
|
||||
"select": {
|
||||
"label": "Цільова мова",
|
||||
"variable": "target_language",
|
||||
"description": "Мова, на яку ви хочете перекласти.",
|
||||
"default": "Chinese",
|
||||
"required": True,
|
||||
'options': [
|
||||
'Chinese',
|
||||
'English',
|
||||
'Japanese',
|
||||
'French',
|
||||
'Russian',
|
||||
'German',
|
||||
'Spanish',
|
||||
'Korean',
|
||||
'Italian',
|
||||
]
|
||||
}
|
||||
}, {
|
||||
"paragraph": {
|
||||
"label": "Запит",
|
||||
"variable": "query",
|
||||
"required": True,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
])
|
||||
)
|
||||
},
|
||||
{
|
||||
"name": "AI інтерв’юер фронтенду",
|
||||
"icon": "",
|
||||
"icon_background": "",
|
||||
"description": "Симульований інтерв’юер фронтенду, який перевіряє рівень кваліфікації у розробці фронтенду через опитування.",
|
||||
"mode": "chat",
|
||||
"model_config": AppModelConfig(
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
configs={
|
||||
"introduction": "Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
|
||||
"prompt_template": "Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
|
||||
"prompt_variables": [],
|
||||
"completion_params": {
|
||||
"max_token": 300,
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.9,
|
||||
"presence_penalty": 0.1,
|
||||
"frequency_penalty": 0.1,
|
||||
},
|
||||
},
|
||||
opening_statement="Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
|
||||
suggested_questions=None,
|
||||
pre_prompt="Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.9,
|
||||
"presence_penalty": 0.1,
|
||||
"frequency_penalty": 0.1,
|
||||
},
|
||||
}),
|
||||
user_input_form=None
|
||||
),
|
||||
}
|
||||
],
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import json
|
||||
|
||||
from models.model import App, AppModelConfig
|
||||
|
||||
model_templates = {
|
||||
# completion default mode
|
||||
'completion_default': {
|
||||
@@ -15,30 +13,14 @@ model_templates = {
|
||||
'status': 'normal'
|
||||
},
|
||||
'model_config': {
|
||||
'provider': 'openai',
|
||||
'model_id': 'gpt-3.5-turbo-instruct',
|
||||
'configs': {
|
||||
'prompt_template': '',
|
||||
'prompt_variables': [],
|
||||
'completion_params': {
|
||||
'max_token': 512,
|
||||
'temperature': 1,
|
||||
'top_p': 1,
|
||||
'presence_penalty': 0,
|
||||
'frequency_penalty': 0,
|
||||
}
|
||||
},
|
||||
'provider': '',
|
||||
'model_id': '',
|
||||
'configs': {},
|
||||
'model': json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {
|
||||
"max_tokens": 512,
|
||||
"temperature": 1,
|
||||
"top_p": 1,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0
|
||||
}
|
||||
"completion_params": {}
|
||||
}),
|
||||
'user_input_form': json.dumps([
|
||||
{
|
||||
@@ -66,30 +48,14 @@ model_templates = {
|
||||
'status': 'normal'
|
||||
},
|
||||
'model_config': {
|
||||
'provider': 'openai',
|
||||
'model_id': 'gpt-3.5-turbo',
|
||||
'configs': {
|
||||
'prompt_template': '',
|
||||
'prompt_variables': [],
|
||||
'completion_params': {
|
||||
'max_token': 512,
|
||||
'temperature': 1,
|
||||
'top_p': 1,
|
||||
'presence_penalty': 0,
|
||||
'frequency_penalty': 0,
|
||||
}
|
||||
},
|
||||
'provider': '',
|
||||
'model_id': '',
|
||||
'configs': {},
|
||||
'model': json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 512,
|
||||
"temperature": 1,
|
||||
"top_p": 1,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0
|
||||
}
|
||||
"completion_params": {}
|
||||
})
|
||||
}
|
||||
},
|
||||
|
||||
@@ -11,13 +11,11 @@ from .app import (advanced_prompt_template, annotation, app, audio, completion,
|
||||
model_config, site, statistic)
|
||||
# Import auth controllers
|
||||
from .auth import activate, data_source_oauth, login, oauth
|
||||
# Import billing controllers
|
||||
from .billing import billing
|
||||
# Import datasets controllers
|
||||
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
|
||||
# Import explore controllers
|
||||
from .explore import audio, completion, conversation, installed_app, message, parameter, recommended_app, saved_message
|
||||
# Import workspace controllers
|
||||
from .workspace import account, members, model_providers, models, tool_providers, workspace
|
||||
# Import billing controllers
|
||||
from .billing import billing
|
||||
# Import operation controllers
|
||||
from .operation import operation
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from extensions.ext_database import db
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from constants.languages import supported_language
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
|
||||
def admin_required(view):
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import flask_restful
|
||||
from extensions.ext_database import db
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from . import api
|
||||
from .setup import setup_required
|
||||
@@ -61,9 +62,7 @@ class BaseApiKeyListResource(Resource):
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id,
|
||||
self.resource_model)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = db.session.query(ApiToken). \
|
||||
@@ -102,7 +101,7 @@ class BaseApiKeyResource(Resource):
|
||||
self.resource_model)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
key = db.session.query(ApiToken). \
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import NoFileUploadedError
|
||||
from controllers.console.datasets.error import TooManyFilesError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import (annotation_fields, annotation_hit_history_fields,
|
||||
annotation_hit_history_list_fields, annotation_list_fields)
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, marshal_with, reqparse
|
||||
from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
annotation_hit_history_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@@ -21,7 +24,7 @@ class AnnotationReplyActionApi(Resource):
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def post(self, app_id, action):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
@@ -45,7 +48,7 @@ class AppAnnotationSettingDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
@@ -59,7 +62,7 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self, app_id, annotation_setting_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
@@ -80,7 +83,7 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def get(self, app_id, job_id, action):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
@@ -108,7 +111,7 @@ class AnnotationListApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
@@ -133,7 +136,7 @@ class AnnotationExportApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
@@ -152,7 +155,7 @@ class AnnotationCreateApi(Resource):
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
@@ -172,7 +175,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id, annotation_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
@@ -189,7 +192,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@account_initialization_required
|
||||
def delete(self, app_id, annotation_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
@@ -205,7 +208,7 @@ class AnnotationBatchImportApi(Resource):
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def post(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
@@ -230,7 +233,7 @@ class AnnotationBatchImportStatusApi(Resource):
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def get(self, app_id, job_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
@@ -257,7 +260,7 @@ class AnnotationHitHistoryListApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_id, annotation_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from constants.model_template import model_templates
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, abort, inputs, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from constants.languages import demo_model_templates, languages
|
||||
from constants.model_template import model_templates
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
|
||||
from controllers.console.setup import setup_required
|
||||
@@ -15,16 +18,18 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from events.app_event import app_was_created, app_was_deleted
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import (app_detail_fields, app_detail_fields_with_site, app_pagination_fields,
|
||||
template_list_fields)
|
||||
from flask import current_app
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, abort, inputs, marshal_with, reqparse
|
||||
from fields.app_fields import (
|
||||
app_detail_fields,
|
||||
app_detail_fields_with_site,
|
||||
app_pagination_fields,
|
||||
template_list_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppModelConfig, Site
|
||||
from models.tools import ApiToolProvider
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
from werkzeug.exceptions import Forbidden
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.entities.application_entities import AgentToolEntity
|
||||
|
||||
def _get_app(app_id, tenant_id):
|
||||
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
|
||||
@@ -88,7 +93,7 @@ class AppListApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
@@ -107,20 +112,27 @@ class AppListApi(Resource):
|
||||
# validate config
|
||||
model_config_dict = args['model_config']
|
||||
|
||||
# get model provider
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
# Get provider configurations
|
||||
provider_manager = ProviderManager()
|
||||
provider_configurations = provider_manager.get_configurations(current_user.current_tenant_id)
|
||||
|
||||
# get available models from provider_configurations
|
||||
available_models = provider_configurations.get_models(
|
||||
model_type=ModelType.LLM,
|
||||
only_active=True
|
||||
)
|
||||
|
||||
if not model_instance:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Default System Reasoning Model available. Please configure "
|
||||
f"in the Settings -> Model Provider.")
|
||||
else:
|
||||
model_config_dict["model"]["provider"] = model_instance.provider
|
||||
model_config_dict["model"]["name"] = model_instance.model
|
||||
# check if model is available
|
||||
available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
|
||||
provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
|
||||
if provider_model not in available_models_names:
|
||||
if not default_model_entity:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Default System Reasoning Model available. Please configure "
|
||||
"in the Settings -> Model Provider.")
|
||||
else:
|
||||
model_config_dict["model"]["provider"] = default_model_entity.provider.provider
|
||||
model_config_dict["model"]["name"] = default_model_entity.model
|
||||
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
@@ -226,7 +238,42 @@ class AppApi(Resource):
|
||||
def get(self, app_id):
|
||||
"""Get app detail"""
|
||||
app_id = str(app_id)
|
||||
app = _get_app(app_id, current_user.current_tenant_id)
|
||||
app: App = _get_app(app_id, current_user.current_tenant_id)
|
||||
|
||||
# get original app model config
|
||||
model_config: AppModelConfig = app.app_model_config
|
||||
agent_mode = model_config.agent_mode_dict
|
||||
# decrypt agent tool parameters if it's secret-input
|
||||
for tool in agent_mode.get('tools') or []:
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
agent_tool=agent_tool_entity,
|
||||
agent_callback=None
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
)
|
||||
|
||||
# get decrypted parameters
|
||||
if agent_tool_entity.tool_parameters:
|
||||
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
||||
masked_parameter = manager.mask_tool_parameters(parameters or {})
|
||||
else:
|
||||
masked_parameter = {}
|
||||
|
||||
# override tool parameters
|
||||
tool['tool_parameters'] = masked_parameter
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# override agent mode
|
||||
model_config.agent_mode = json.dumps(agent_mode)
|
||||
|
||||
return app
|
||||
|
||||
@@ -237,7 +284,7 @@ class AppApi(Resource):
|
||||
"""Delete app"""
|
||||
app_id = str(app_id)
|
||||
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
app = _get_app(app_id, current_user.current_tenant_id)
|
||||
|
||||
@@ -1,24 +1,35 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError,
|
||||
NoAudioUploadedError, ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError, ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError, UnsupportedAudioTypeError)
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from libs.login import login_required
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError)
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
class ChatMessageAudioApi(Resource):
|
||||
@@ -34,7 +45,8 @@ class ChatMessageAudioApi(Resource):
|
||||
try:
|
||||
response = AudioService.transcript_asr(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file
|
||||
file=file,
|
||||
end_user=None,
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -60,7 +72,7 @@ class ChatMessageAudioApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logging.exception(f"internal server error, {str(e)}.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@@ -71,10 +83,12 @@ class ChatMessageTextApi(Resource):
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
app_model = _get_app(app_id, None)
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_tts(
|
||||
tenant_id=app_model.tenant_id,
|
||||
text=request.form['text'],
|
||||
voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=False
|
||||
)
|
||||
|
||||
@@ -101,9 +115,50 @@ class ChatMessageTextApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logging.exception(f"internal server error, {str(e)}.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TextModesApi(Resource):
|
||||
def get(self, app_id: str):
|
||||
app_model = _get_app(str(app_id))
|
||||
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('language', type=str, required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
response = AudioService.transcript_tts_voices(
|
||||
tenant_id=app_model.tenant_id,
|
||||
language=args['language'],
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.audio.ProviderNotSupportTextToSpeechLanageServiceError:
|
||||
raise AppUnavailableError("Text to audio voices language parameter loss.")
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception(f"internal server error, {str(e)}.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
|
||||
api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
|
||||
api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
|
||||
|
||||
@@ -1,27 +1,33 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
import flask_login
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError)
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import Resource, reqparse
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from services.completion_service import CompletionService
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
|
||||
# define completion message api for user
|
||||
@@ -163,8 +169,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,22 +1,27 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import joinedload
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (conversation_detail_fields, conversation_message_detail_fields,
|
||||
conversation_pagination_fields, conversation_with_summary_pagination_fields)
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from fields.conversation_fields import (
|
||||
conversation_detail_fields,
|
||||
conversation_message_detail_fields,
|
||||
conversation_pagination_fields,
|
||||
conversation_with_summary_pagination_fields,
|
||||
)
|
||||
from libs.helper import datetime_string
|
||||
from libs.login import login_required
|
||||
from models.model import Conversation, Message, MessageAnnotation
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import joinedload
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
class CompletionConversationApi(Resource):
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError, ProviderQuotaExceededError)
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from libs.login import login_required
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,23 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.app.error import (AppMoreLikeThisDisabledError, CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError)
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
@@ -14,10 +25,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import annotation_fields, message_detail_fields
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import login_required
|
||||
@@ -28,7 +35,6 @@ from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.message_service import MessageService
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
|
||||
class ChatMessageListApi(Resource):
|
||||
@@ -157,7 +163,7 @@ class MessageAnnotationApi(Resource):
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
@@ -241,8 +247,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.entities.application_entities import AgentToolEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource
|
||||
from libs.login import login_required
|
||||
from models.model import AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
@@ -38,6 +42,88 @@ class ModelConfigResource(Resource):
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
# get original app model config
|
||||
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
|
||||
AppModelConfig.id == app.app_model_config_id
|
||||
).first()
|
||||
agent_mode = original_app_model_config.agent_mode_dict
|
||||
# decrypt agent tool parameters if it's secret-input
|
||||
parameter_map = {}
|
||||
masked_parameter_map = {}
|
||||
tool_map = {}
|
||||
for tool in agent_mode.get('tools') or []:
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
agent_tool=agent_tool_entity,
|
||||
agent_callback=None
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
)
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
# get decrypted parameters
|
||||
if agent_tool_entity.tool_parameters:
|
||||
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
||||
masked_parameter = manager.mask_tool_parameters(parameters or {})
|
||||
else:
|
||||
parameters = {}
|
||||
masked_parameter = {}
|
||||
|
||||
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
|
||||
masked_parameter_map[key] = masked_parameter
|
||||
parameter_map[key] = parameters
|
||||
tool_map[key] = tool_runtime
|
||||
|
||||
# encrypt agent tool parameters if it's secret-input
|
||||
agent_mode = new_app_model_config.agent_mode_dict
|
||||
for tool in agent_mode.get('tools') or []:
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
|
||||
# get tool
|
||||
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
|
||||
if key in tool_map:
|
||||
tool_runtime = tool_map[key]
|
||||
else:
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
agent_tool=agent_tool_entity,
|
||||
agent_callback=None
|
||||
)
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
)
|
||||
manager.delete_tool_parameters_cache()
|
||||
|
||||
# override parameters if it equals to masked parameters
|
||||
if agent_tool_entity.tool_parameters:
|
||||
if key not in masked_parameter_map:
|
||||
continue
|
||||
|
||||
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
|
||||
agent_tool_entity.tool_parameters = parameter_map[key]
|
||||
|
||||
# encrypt parameters
|
||||
if agent_tool_entity.tool_parameters:
|
||||
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
||||
|
||||
# update app model config
|
||||
new_app_model_config.agent_mode = json.dumps(agent_mode)
|
||||
|
||||
db.session.add(new_app_model_config)
|
||||
db.session.flush()
|
||||
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_site_fields
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from constants.languages import supported_language
|
||||
from libs.login import login_required
|
||||
from models.model import Site
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
|
||||
def parse_app_site_args():
|
||||
@@ -42,7 +42,7 @@ class AppSite(Resource):
|
||||
app_model = _get_app(app_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
site = db.session.query(Site). \
|
||||
@@ -88,7 +88,7 @@ class AppSiteAccessTokenReset(Resource):
|
||||
app_model = _get_app(app_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytz
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from libs.helper import datetime_string
|
||||
from libs.login import login_required
|
||||
|
||||
|
||||
@@ -2,14 +2,15 @@ import base64
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from flask_restful import Resource, reqparse
|
||||
from libs.helper import email, str_len, timezone
|
||||
from constants.languages import supported_language
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import AccountStatus, Tenant
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import RegisterService
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from controllers.console import api
|
||||
from flask import current_app, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from ..setup import setup_required
|
||||
from ..wraps import account_initialization_required
|
||||
@@ -30,7 +31,7 @@ def get_oauth_providers():
|
||||
class OAuthDataSource(Resource):
|
||||
def get(self, provider: str):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import flask
|
||||
import flask_login
|
||||
from flask import current_app, request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from flask import current_app, request
|
||||
from flask_restful import Resource, reqparse
|
||||
from libs.helper import email
|
||||
from libs.password import valid_password
|
||||
from services.account_service import AccountService, TenantService
|
||||
@@ -30,10 +29,7 @@ class LoginApi(Resource):
|
||||
except services.errors.account.AccountLoginError:
|
||||
return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401
|
||||
|
||||
try:
|
||||
TenantService.switch_tenant(account)
|
||||
except Exception:
|
||||
pass
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
@@ -47,7 +43,6 @@ class LogoutApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def get(self):
|
||||
flask.session.pop('workspace_id', None)
|
||||
flask_login.logout_user()
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
@@ -3,13 +3,14 @@ from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restful import Resource
|
||||
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from models.account import Account, AccountStatus
|
||||
from services.account_service import AccountService, RegisterService
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
|
||||
from .. import api
|
||||
|
||||
@@ -75,6 +76,8 @@ class OAuthCallback(Resource):
|
||||
account.initialized_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
token = AccountService.get_account_jwt_token(account)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
@@ -20,7 +21,7 @@ class Subscription(Resource):
|
||||
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
|
||||
args = parser.parse_args()
|
||||
|
||||
BillingService.is_tenant_owner(current_user)
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
|
||||
return BillingService.get_subscription(args['plan'],
|
||||
args['interval'],
|
||||
@@ -35,8 +36,8 @@ class Invoices(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
BillingService.is_tenant_owner(current_user)
|
||||
return BillingService.get_invoices(current_user.email)
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
||||
|
||||
|
||||
api.add_resource(Subscription, '/billing/subscription')
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.data_loader.loader.notion import NotionLoader
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from libs.login import login_required
|
||||
from models.dataset import Document
|
||||
from models.source import DataSourceBinding
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
class DataSourceApi(Resource):
|
||||
@@ -172,14 +174,15 @@ class DataSourceNotionApi(Resource):
|
||||
if not data_source_binding:
|
||||
raise NotFound('Data source binding not found.')
|
||||
|
||||
loader = NotionLoader(
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
|
||||
text_docs = loader.load()
|
||||
text_docs = extractor.extract()
|
||||
return {
|
||||
'content': "\n".join([doc.page_content for doc in text_docs])
|
||||
}, 200
|
||||
@@ -191,11 +194,31 @@ class DataSourceNotionApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
notion_info_list = args['notion_info_list']
|
||||
extract_settings = []
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
for page in notion_info['pages']:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page['page_id'],
|
||||
"notion_page_type": page['type'],
|
||||
"tenant_id": current_user.current_tenant_id
|
||||
},
|
||||
document_model=args['doc_form']
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
|
||||
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'])
|
||||
return response, 200
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import flask_restful
|
||||
from flask import current_app, request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.apikey import api_key_fields, api_key_list
|
||||
@@ -11,18 +15,15 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
from fields.document_fields import document_status_fields
|
||||
from flask import current_app, request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, marshal_with, reqparse
|
||||
from libs.login import login_required
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.model import ApiToken, UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
@@ -103,7 +104,7 @@ class DatasetListApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
@@ -178,16 +179,16 @@ class DatasetApi(Resource):
|
||||
location='json', store_missing=False,
|
||||
type=_validate_description_length)
|
||||
parser.add_argument('indexing_technique', type=str, location='json',
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid indexing technique.')
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid indexing technique.')
|
||||
parser.add_argument('permission', type=str, location='json', choices=(
|
||||
'only_me', 'all_team_members'), help='Invalid permission.')
|
||||
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
dataset = DatasetService.update_dataset(
|
||||
@@ -205,7 +206,7 @@ class DatasetApi(Resource):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
||||
@@ -258,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('indexing_technique', type=str, required=True,
|
||||
parser.add_argument('indexing_technique', type=str, required=True,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
@@ -268,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
extract_settings = []
|
||||
if args['info_list']['data_source_type'] == 'upload_file':
|
||||
file_ids = args['info_list']['file_info_list']['file_ids']
|
||||
file_details = db.session.query(UploadFile).filter(
|
||||
@@ -278,37 +280,45 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
if file_details is None:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'],
|
||||
args['indexing_technique'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
if file_details:
|
||||
for file_detail in file_details:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
upload_file=file_detail,
|
||||
document_model=args['doc_form']
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif args['info_list']['data_source_type'] == 'notion_import':
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
||||
try:
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
|
||||
args['info_list']['notion_info_list'],
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'],
|
||||
args['indexing_technique'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
notion_info_list = args['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
for page in notion_info['pages']:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page['page_id'],
|
||||
"notion_page_type": page['type'],
|
||||
"tenant_id": current_user.current_tenant_id
|
||||
},
|
||||
document_model=args['doc_form']
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'],
|
||||
args['indexing_technique'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
return response, 200
|
||||
|
||||
|
||||
@@ -391,7 +401,7 @@ class DatasetApiKeyApi(Resource):
|
||||
@marshal_with(api_key_fields)
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = db.session.query(ApiToken). \
|
||||
@@ -425,7 +435,7 @@ class DatasetApiDeleteApi(Resource):
|
||||
api_key_id = str(api_key_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
key = db.session.query(ApiToken). \
|
||||
@@ -508,4 +518,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
|
||||
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
||||
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
|
||||
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
|
||||
|
||||
|
||||
@@ -1,36 +1,52 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import asc, desc
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError)
|
||||
from controllers.console.datasets.error import (ArchivedDocumentImmutableError, DocumentAlreadyFinishedError,
|
||||
DocumentIndexingError, InvalidActionError, InvalidMetadataError)
|
||||
from controllers.console.app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import (
|
||||
ArchivedDocumentImmutableError,
|
||||
DocumentAlreadyFinishedError,
|
||||
DocumentIndexingError,
|
||||
InvalidActionError,
|
||||
InvalidMetadataError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.errors.error import (LLMBadRequestError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError,
|
||||
QuotaExceededError)
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.document_fields import (dataset_and_document_fields, document_fields, document_status_fields,
|
||||
document_with_segments_fields)
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
|
||||
from fields.document_fields import (
|
||||
dataset_and_document_fields,
|
||||
document_fields,
|
||||
document_status_fields,
|
||||
document_with_segments_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from sqlalchemy import asc, desc
|
||||
from tasks.add_document_to_index_task import add_document_to_index_task
|
||||
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
@@ -54,7 +70,7 @@ class DocumentResource(Resource):
|
||||
|
||||
return document
|
||||
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
@@ -80,7 +96,7 @@ class GetProcessRuleApi(Resource):
|
||||
req_data = request.args
|
||||
|
||||
document_id = req_data.get('document_id')
|
||||
|
||||
|
||||
# get default rules
|
||||
mode = DocumentService.DEFAULT_RULES['mode']
|
||||
rules = DocumentService.DEFAULT_RULES['rules']
|
||||
@@ -204,7 +220,7 @@ class DatasetDocumentListApi(Resource):
|
||||
raise NotFound('Dataset not found.')
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
@@ -256,7 +272,7 @@ class DatasetInitApi(Resource):
|
||||
@cloud_edition_billing_resource_check('vector_space')
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -279,8 +295,8 @@ class DatasetInitApi(Resource):
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@@ -347,16 +363,22 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
if not file:
|
||||
raise NotFound('File not found.')
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
upload_file=file,
|
||||
document_model=document.doc_form
|
||||
)
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
|
||||
data_process_rule_dict, None,
|
||||
'English', dataset_id)
|
||||
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
|
||||
data_process_rule_dict, document.doc_form,
|
||||
'English', dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@@ -387,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
info_list = []
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in ['completed', 'error']:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
@@ -409,42 +432,49 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
}
|
||||
info_list.append(notion_info)
|
||||
|
||||
if dataset.data_source_type == 'upload_file':
|
||||
file_details = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == current_user.current_tenant_id,
|
||||
UploadFile.id.in_(info_list)
|
||||
).all()
|
||||
if document.data_source_type == 'upload_file':
|
||||
file_id = data_source_info['upload_file_id']
|
||||
file_detail = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == current_user.current_tenant_id,
|
||||
UploadFile.id == file_id
|
||||
).first()
|
||||
|
||||
if file_details is None:
|
||||
raise NotFound("File not found.")
|
||||
if file_detail is None:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
upload_file=file_detail,
|
||||
document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
elif document.data_source_type == 'notion_import':
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"notion_workspace_id": data_source_info['notion_workspace_id'],
|
||||
"notion_obj_id": data_source_info['notion_page_id'],
|
||||
"notion_page_type": data_source_info['type'],
|
||||
"tenant_id": current_user.current_tenant_id
|
||||
},
|
||||
document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
|
||||
data_process_rule_dict, None,
|
||||
'English', dataset_id)
|
||||
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
|
||||
data_process_rule_dict, document.doc_form,
|
||||
'English', dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
elif dataset.data_source_type == 'notion_import':
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
|
||||
info_list,
|
||||
data_process_rule_dict,
|
||||
None, 'English', dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
return response
|
||||
|
||||
|
||||
@@ -599,7 +629,7 @@ class DocumentProcessingApi(DocumentResource):
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
if action == "pause":
|
||||
@@ -663,7 +693,7 @@ class DocumentMetadataApi(DocumentResource):
|
||||
doc_metadata = req_data.get('doc_metadata')
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
if doc_type is None or doc_metadata is None:
|
||||
@@ -710,7 +740,7 @@ class DocumentStatusApi(DocumentResource):
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
indexing_cache_key = 'document_{}_indexing'.format(document.id)
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
@@ -15,16 +19,12 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import segment_fields
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from libs.login import login_required
|
||||
from models.dataset import DocumentSegment
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
|
||||
class DatasetDocumentSegmentListApi(Resource):
|
||||
@@ -123,7 +123,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
@@ -142,8 +142,8 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@@ -219,7 +219,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
@@ -233,8 +233,8 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
try:
|
||||
@@ -285,8 +285,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
@@ -298,7 +298,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
if not segment:
|
||||
raise NotFound('Segment not found.')
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
@@ -342,7 +342,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
if not segment:
|
||||
raise NotFound('Segment not found.')
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import (FileTooLargeError, NoFileUploadedError, TooManyFilesError,
|
||||
UnsupportedFileTypeError)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.file_fields import file_fields, upload_config_fields
|
||||
from flask import current_app, request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import (
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from fields.file_fields import file_fields, upload_config_fields
|
||||
from libs.login import login_required
|
||||
from services.file_service import FileService, ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS
|
||||
from services.file_service import ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS, FileService
|
||||
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
@@ -34,6 +39,7 @@ class FileApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(file_fields)
|
||||
@cloud_edition_billing_resource_check(resource='documents')
|
||||
def post(self):
|
||||
|
||||
# get file from request
|
||||
|
||||
@@ -1,22 +1,31 @@
|
||||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError, ProviderQuotaExceededError)
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError, HighQualityDatasetOnlyError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.errors.error import (LLMBadRequestError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError,
|
||||
QuotaExceededError)
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
|
||||
class HitTestingApi(Resource):
|
||||
@@ -67,8 +76,8 @@ class HitTestingApi(Resource):
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -13,6 +13,16 @@ class NotSetupError(BaseHTTPException):
|
||||
"Please proceed with the initialization and installation process first."
|
||||
code = 401
|
||||
|
||||
class NotInitValidateError(BaseHTTPException):
|
||||
error_code = 'not_init_validated'
|
||||
description = "Init validation has not been completed yet. " \
|
||||
"Please proceed with the init validation process first."
|
||||
code = 401
|
||||
|
||||
class InitValidateFailedError(BaseHTTPException):
|
||||
error_code = 'init_validate_failed'
|
||||
description = "Init validation failed. Please check the password and try again."
|
||||
code = 401
|
||||
|
||||
class AccountNotLinkTenantError(BaseHTTPException):
|
||||
error_code = 'account_not_link_tenant'
|
||||
|
||||
@@ -1,21 +1,32 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError,
|
||||
NoAudioUploadedError, ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError, ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError, UnsupportedAudioTypeError)
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from flask import request
|
||||
from models.model import AppModelConfig
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError)
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
class ChatAudioApi(InstalledAppResource):
|
||||
@@ -74,6 +85,7 @@ class ChatTextApi(InstalledAppResource):
|
||||
response = AudioService.transcript_tts(
|
||||
tenant_id=app_model.tenant_id,
|
||||
text=request.form['text'],
|
||||
voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=False
|
||||
)
|
||||
return {'data': response.data.decode('latin1')}
|
||||
|
||||
@@ -1,14 +1,24 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Generator, Union
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
from flask_restful import reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError)
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
@@ -16,12 +26,8 @@ from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
from flask_restful import reqparse
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
|
||||
# define completion api for user
|
||||
@@ -158,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from flask_login import current_user
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import uuid_value
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, inputs, marshal_with, reqparse
|
||||
from sqlalchemy import and_
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from extensions.ext_database import db
|
||||
from fields.installed_app_fields import installed_app_list_fields
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, inputs, marshal_with, reqparse
|
||||
from libs.login import login_required
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from services.account_service import TenantService
|
||||
from sqlalchemy import and_
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
|
||||
class InstalledAppsListApi(Resource):
|
||||
|
||||
@@ -1,31 +1,39 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (AppMoreLikeThisDisabledError, CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError)
|
||||
from controllers.console.explore.error import (AppSuggestedQuestionsAfterAnswerDisabledError, NotChatAppError,
|
||||
NotCompletionAppError)
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.explore.error import (
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal_with, reqparse, fields
|
||||
from flask_restful.inputs import int_range
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@@ -115,8 +123,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
from models.model import InstalledApp, AppModelConfig
|
||||
from extensions.ext_database import db
|
||||
from models.model import AppModelConfig, InstalledApp
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
class AppParameterApi(InstalledAppResource):
|
||||
"""Resource for app variables."""
|
||||
@@ -77,7 +77,7 @@ class ExploreAppMetaApi(InstalledAppResource):
|
||||
# get all tools
|
||||
tools = agent_config.get('tools', [])
|
||||
url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
+ f"/console/api/workspaces/current/tool-provider/builtin/")
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/")
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from sqlalchemy import and_
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from libs.login import login_required
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from services.account_service import TenantService
|
||||
from sqlalchemy import and_
|
||||
from constants.languages import languages
|
||||
|
||||
app_fields = {
|
||||
'id': fields.String,
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from flask_login import current_user
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource
|
||||
from libs.login import login_required
|
||||
from models.model import InstalledApp
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
def installed_app_required(view=None):
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from libs.login import login_required
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource
|
||||
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import api
|
||||
from .wraps import cloud_utm_record
|
||||
|
||||
|
||||
class FeatureApi(Resource):
|
||||
|
||||
@cloud_utm_record
|
||||
def get(self):
|
||||
return FeatureService.get_features(current_user.current_tenant_id).dict()
|
||||
|
||||
|
||||
49
api/controllers/console/init_validate.py
Normal file
49
api/controllers/console/init_validate.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
|
||||
from flask import current_app, session
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from libs.helper import str_len
|
||||
from models.model import DifySetup
|
||||
from services.account_service import TenantService
|
||||
|
||||
from . import api
|
||||
from .error import AlreadySetupError, InitValidateFailedError
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
|
||||
class InitValidateAPI(Resource):
|
||||
|
||||
def get(self):
|
||||
init_status = get_init_validate_status()
|
||||
if init_status:
|
||||
return { 'status': 'finished' }
|
||||
return {'status': 'not_started' }
|
||||
|
||||
@only_edition_self_hosted
|
||||
def post(self):
|
||||
# is tenant created
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('password', type=str_len(30),
|
||||
required=True, location='json')
|
||||
input_password = parser.parse_args()['password']
|
||||
|
||||
if input_password != os.environ.get('INIT_PASSWORD'):
|
||||
session['is_init_validated'] = False
|
||||
raise InitValidateFailedError()
|
||||
|
||||
session['is_init_validated'] = True
|
||||
return {'result': 'success'}, 201
|
||||
|
||||
def get_init_validate_status():
|
||||
if current_app.config['EDITION'] == 'SELF_HOSTED':
|
||||
if os.environ.get('INIT_PASSWORD'):
|
||||
return session.get('is_init_validated') or DifySetup.query.first()
|
||||
|
||||
return True
|
||||
|
||||
api.add_resource(InitValidateAPI, '/init')
|
||||
@@ -1,30 +0,0 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud
|
||||
from libs.login import login_required
|
||||
from services.operation_service import OperationService
|
||||
|
||||
|
||||
class TenantUtm(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('utm_source', type=str, required=True)
|
||||
parser.add_argument('utm_medium', type=str, required=True)
|
||||
parser.add_argument('utm_campaign', type=str, required=False, default='')
|
||||
parser.add_argument('utm_content', type=str, required=False, default='')
|
||||
parser.add_argument('utm_term', type=str, required=False, default='')
|
||||
args = parser.parse_args()
|
||||
|
||||
return OperationService.record_utm(current_user.current_tenant_id, args)
|
||||
|
||||
|
||||
api.add_resource(TenantUtm, '/operation/utm')
|
||||
@@ -1,16 +1,17 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from functools import wraps
|
||||
|
||||
from extensions.ext_database import db
|
||||
from flask import current_app, request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, str_len
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
|
||||
from . import api
|
||||
from .error import AlreadySetupError, NotSetupError
|
||||
from .error import AlreadySetupError, NotInitValidateError, NotSetupError
|
||||
from .init_validate import get_init_validate_status
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
|
||||
@@ -24,7 +25,7 @@ class SetupApi(Resource):
|
||||
'step': 'finished',
|
||||
'setup_at': setup_status.setup_at.isoformat()
|
||||
}
|
||||
return {'step': 'not_start'}
|
||||
return {'step': 'not_started'}
|
||||
return {'step': 'finished'}
|
||||
|
||||
@only_edition_self_hosted
|
||||
@@ -37,6 +38,9 @@ class SetupApi(Resource):
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
if not get_init_validate_status():
|
||||
raise NotInitValidateError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('email', type=email,
|
||||
@@ -71,7 +75,10 @@ def setup_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# check setup
|
||||
if not get_setup_status():
|
||||
if not get_init_validate_status():
|
||||
raise NotInitValidateError()
|
||||
|
||||
elif not get_setup_status():
|
||||
raise NotSetupError()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import json
|
||||
import logging
|
||||
@@ -6,7 +5,6 @@ import logging
|
||||
import requests
|
||||
from flask import current_app
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from . import api
|
||||
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.workspace.error import (AccountAlreadyInitedError, CurrentPasswordIncorrectError,
|
||||
InvalidInvitationCodeError, RepeatPasswordNotMatchError)
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from flask import current_app, request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from libs.helper import TimestampField, timezone
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.workspace.error import (
|
||||
AccountAlreadyInitedError,
|
||||
CurrentPasswordIncorrectError,
|
||||
InvalidInvitationCodeError,
|
||||
RepeatPasswordNotMatchError,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField, timezone
|
||||
from libs.login import login_required
|
||||
from models.account import AccountIntegrate, InvitationCode
|
||||
from services.account_service import AccountService
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import current_app
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, abort, fields, marshal_with, reqparse
|
||||
@@ -12,6 +11,7 @@ from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.account_service import RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
@@ -52,10 +52,12 @@ class MemberInviteEmailApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('emails', type=str, required=True, location='json', action='append')
|
||||
parser.add_argument('role', type=str, required=True, default='admin', location='json')
|
||||
parser.add_argument('language', type=str, required=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
invitee_emails = args['emails']
|
||||
invitee_role = args['role']
|
||||
interface_language = args['language']
|
||||
if invitee_role not in ['admin', 'normal']:
|
||||
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
|
||||
|
||||
@@ -64,13 +66,19 @@ class MemberInviteEmailApi(Resource):
|
||||
console_web_url = current_app.config.get("CONSOLE_WEB_URL")
|
||||
for invitee_email in invitee_emails:
|
||||
try:
|
||||
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
|
||||
inviter=inviter)
|
||||
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter)
|
||||
invitation_results.append({
|
||||
'status': 'success',
|
||||
'email': invitee_email,
|
||||
'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
|
||||
})
|
||||
except AccountAlreadyInTenantError:
|
||||
invitation_results.append({
|
||||
'status': 'success',
|
||||
'email': invitee_email,
|
||||
'url': f'{console_web_url}/signin'
|
||||
})
|
||||
break
|
||||
except Exception as e:
|
||||
invitation_results.append({
|
||||
'status': 'failed',
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
import io
|
||||
|
||||
from flask import send_file
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from flask import send_file
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
class ModelProviderListApi(Resource):
|
||||
@@ -98,7 +99,7 @@ class ModelProviderApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -122,7 +123,7 @@ class ModelProviderApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@@ -159,7 +160,7 @@ class PreferredProviderTypeUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
@@ -186,10 +187,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
def get(self, provider: str):
|
||||
if provider != 'anthropic':
|
||||
raise ValueError(f'provider name {provider} is invalid')
|
||||
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
data = BillingService.get_model_provider_payment_link(provider_name=provider,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account_id=current_user.id)
|
||||
account_id=current_user.id,
|
||||
prefilled_email=current_user.email)
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from libs.login import login_required
|
||||
from services.model_provider_service import ModelProviderService
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
class DefaultModelApi(Resource):
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
import json
|
||||
import io
|
||||
|
||||
from libs.login import login_required
|
||||
from flask import send_file
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from flask import send_file
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
|
||||
from libs.login import login_required
|
||||
from services.tools_manage_service import ToolManageService
|
||||
|
||||
import io
|
||||
|
||||
class ToolProviderListApi(Resource):
|
||||
@setup_required
|
||||
@@ -43,7 +41,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
@@ -60,7 +58,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
@@ -84,13 +82,37 @@ class ToolBuiltinProviderIconApi(Resource):
|
||||
icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider)
|
||||
return send_file(io.BytesIO(icon_bytes), mimetype=minetype)
|
||||
|
||||
class ToolModelProviderIconApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider)
|
||||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype)
|
||||
|
||||
class ToolModelProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.list_model_tool_provider_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
)
|
||||
|
||||
class ToolApiProviderAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
@@ -159,7 +181,7 @@ class ToolApiProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
@@ -171,8 +193,8 @@ class ToolApiProviderUpdateApi(Resource):
|
||||
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('icon', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('privacy_policy', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -193,7 +215,7 @@ class ToolApiProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
@@ -261,6 +283,7 @@ class ToolApiProviderPreviousTestApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
|
||||
@@ -270,6 +293,7 @@ class ToolApiProviderPreviousTestApi(Resource):
|
||||
|
||||
return ToolManageService.test_api_tool_preview(
|
||||
current_user.current_tenant_id,
|
||||
args['provider_name'] if args['provider_name'] else '',
|
||||
args['tool_name'],
|
||||
args['credentials'],
|
||||
args['parameters'],
|
||||
@@ -283,6 +307,8 @@ api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provide
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
|
||||
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
|
||||
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
|
||||
api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon')
|
||||
api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools')
|
||||
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
|
||||
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
|
||||
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.admin import admin_required
|
||||
from controllers.console.datasets.error import (FileTooLargeError, NoFileUploadedError, TooManyFilesError,
|
||||
UnsupportedFileTypeError)
|
||||
from controllers.console.datasets.error import (
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from extensions.ext_database import db
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from models.account import Tenant
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
from functools import wraps
|
||||
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from flask import abort, current_app
|
||||
from flask import abort, current_app, request
|
||||
from flask_login import current_user
|
||||
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from services.feature_service import FeatureService
|
||||
from services.operation_service import OperationService
|
||||
|
||||
|
||||
def account_initialization_required(view):
|
||||
@@ -54,6 +56,7 @@ def cloud_edition_billing_resource_check(resource: str,
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
vector_space = features.vector_space
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
annotation_quota_limit = features.annotation_quota_limit
|
||||
|
||||
if resource == 'members' and 0 < members.limit <= members.size:
|
||||
@@ -62,6 +65,13 @@ def cloud_edition_billing_resource_check(resource: str,
|
||||
abort(403, error_msg)
|
||||
elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
|
||||
abort(403, error_msg)
|
||||
elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
|
||||
source = request.args.get('source')
|
||||
if source == 'datasets':
|
||||
abort(403, error_msg)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
elif resource == 'workspace_custom' and not features.can_replace_logo:
|
||||
abort(403, error_msg)
|
||||
elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
|
||||
@@ -73,3 +83,20 @@ def cloud_edition_billing_resource_check(resource: str,
|
||||
return decorated
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_utm_record(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
try:
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
utm_info = request.cookies.get('utm_info')
|
||||
|
||||
if utm_info:
|
||||
utm_info = json.loads(utm_info)
|
||||
OperationService.record_utm(current_user.current_tenant_id, utm_info)
|
||||
except Exception as e:
|
||||
pass
|
||||
return view(*args, **kwargs)
|
||||
return decorated
|
||||
|
||||
@@ -6,5 +6,4 @@ bp = Blueprint('files', __name__)
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from . import image_preview
|
||||
from . import tool_files
|
||||
from . import image_preview, tool_files
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import services
|
||||
from controllers.files import api
|
||||
from flask import Response, request
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services
|
||||
from controllers.files import api
|
||||
from libs.exception import BaseHTTPException
|
||||
from services.account_service import TenantService
|
||||
from services.file_service import FileService
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
class ImagePreviewApi(Resource):
|
||||
@@ -40,7 +41,7 @@ class WorkspaceWebappLogoApi(Resource):
|
||||
webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None
|
||||
|
||||
if not webapp_logo_file_id:
|
||||
raise NotFound(f'webapp logo is not found')
|
||||
raise NotFound('webapp logo is not found')
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService.get_public_image_preview(
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from controllers.files import api
|
||||
from flask import Response
|
||||
from flask_restful import Resource, reqparse
|
||||
from libs.exception import BaseHTTPException
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.files import api
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class ToolFilePreviewApi(Resource):
|
||||
def get(self, file_id, extension):
|
||||
@@ -31,7 +32,7 @@ class ToolFilePreviewApi(Resource):
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise NotFound(f'file is not found')
|
||||
raise NotFound('file is not found')
|
||||
|
||||
generator, mimetype = result
|
||||
except Exception:
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def create_or_update_end_user_for_user_id(app_model, user_id):
|
||||
"""
|
||||
Create or update session terminal based on user ID.
|
||||
"""
|
||||
end_user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.type == 'service_api'
|
||||
).first()
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='service_api',
|
||||
is_anonymous=True,
|
||||
session_id=user_id
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
from flask_restful import fields, marshal_with, Resource
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
import json
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
||||
class AppParameterApi(AppApiResource):
|
||||
class AppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
variable_fields = {
|
||||
@@ -43,8 +42,9 @@ class AppParameterApi(AppApiResource):
|
||||
'system_parameters': fields.Nested(system_parameters_fields)
|
||||
}
|
||||
|
||||
@validate_app_token
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, app_model: App, end_user):
|
||||
def get(self, app_model: App):
|
||||
"""Retrieve app parameters."""
|
||||
app_model_config = app_model.app_model_config
|
||||
|
||||
@@ -65,8 +65,9 @@ class AppParameterApi(AppApiResource):
|
||||
}
|
||||
}
|
||||
|
||||
class AppMetaApi(AppApiResource):
|
||||
def get(self, app_model: App, end_user):
|
||||
class AppMetaApi(Resource):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
"""Get app meta"""
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
@@ -78,7 +79,7 @@ class AppMetaApi(AppApiResource):
|
||||
# get all tools
|
||||
tools = agent_config.get('tools', [])
|
||||
url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
+ f"/console/api/workspaces/current/tool-provider/builtin/")
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/")
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
|
||||
@@ -1,25 +1,38 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError,
|
||||
NoAudioUploadedError, ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError, ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError, UnsupportedAudioTypeError)
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from flask import request
|
||||
from flask_restful import reqparse
|
||||
from models.model import App, AppModelConfig
|
||||
from models.model import App, AppModelConfig, EndUser
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError)
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
class AudioApi(AppApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
class AudioApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.speech_to_text_dict['enabled']:
|
||||
@@ -61,11 +74,11 @@ class AudioApi(AppApiResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TextApi(AppApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
class TextApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -73,7 +86,8 @@ class TextApi(AppApiResource):
|
||||
response = AudioService.transcript_tts(
|
||||
tenant_id=app_model.tenant_id,
|
||||
text=args['text'],
|
||||
end_user=args['user'],
|
||||
end_user=end_user,
|
||||
voice=args['voice'] if args['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=args['streaming']
|
||||
)
|
||||
|
||||
|
||||
@@ -1,27 +1,36 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||
from controllers.service_api.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError,
|
||||
NotChatAppError, ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError, ProviderQuotaExceededError)
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NotChatAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from flask import Response, stream_with_context, request
|
||||
from flask_restful import reqparse
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, EndUser
|
||||
from services.completion_service import CompletionService
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
|
||||
class CompletionApi(AppApiResource):
|
||||
def post(self, app_model, end_user):
|
||||
class CompletionApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != 'completion':
|
||||
raise AppUnavailableError()
|
||||
|
||||
@@ -30,16 +39,12 @@ class CompletionApi(AppApiResource):
|
||||
parser.add_argument('query', type=str, location='json', default='')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
try:
|
||||
@@ -74,25 +79,20 @@ class CompletionApi(AppApiResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class CompletionStopApi(AppApiResource):
|
||||
def post(self, app_model, _, task_id):
|
||||
class CompletionStopApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, task_id):
|
||||
if app_model.mode != 'completion':
|
||||
raise AppUnavailableError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
end_user_id = args.get('user')
|
||||
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class ChatApi(AppApiResource):
|
||||
def post(self, app_model, end_user):
|
||||
class ChatApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
@@ -102,7 +102,6 @@ class ChatApi(AppApiResource):
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
|
||||
|
||||
@@ -110,9 +109,6 @@ class ChatApi(AppApiResource):
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
try:
|
||||
response = CompletionService.completion(
|
||||
app_model=app_model,
|
||||
@@ -145,14 +141,13 @@ class ChatApi(AppApiResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class ChatStopApi(AppApiResource):
|
||||
def post(self, app_model, _, task_id):
|
||||
class ChatStopApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, task_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
end_user_id = request.get_json().get('user')
|
||||
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
@@ -162,8 +157,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,52 +1,44 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from flask import request
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.conversation_service import ConversationService
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
class ConversationApi(AppApiResource):
|
||||
|
||||
class ConversationApi(Resource):
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('user', type=str, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
try:
|
||||
return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
|
||||
except services.errors.conversation.LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
class ConversationDetailApi(AppApiResource):
|
||||
class ConversationDetailApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
user = request.get_json().get('user')
|
||||
|
||||
if end_user is None and user is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user)
|
||||
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class ConversationRenameApi(AppApiResource):
|
||||
class ConversationRenameApi(Resource):
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model, end_user, c_id):
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=False, location='json')
|
||||
parser.add_argument('user', type=str, location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
try:
|
||||
return ConversationService.rename(
|
||||
app_model,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
|
||||
@@ -1,25 +1,27 @@
|
||||
from flask import request
|
||||
from flask_restful import Resource, marshal_with
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||
from controllers.service_api.app.error import (FileTooLargeError, NoFileUploadedError, TooManyFilesError,
|
||||
UnsupportedFileTypeError)
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from controllers.service_api.app.error import (
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from fields.file_fields import file_fields
|
||||
from flask import request
|
||||
from flask_restful import marshal_with
|
||||
from models.model import App, EndUser
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class FileApi(AppApiResource):
|
||||
class FileApi(Resource):
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
|
||||
@marshal_with(file_fields)
|
||||
def post(self, app_model, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
|
||||
file = request.files['file']
|
||||
user_args = request.form.get('user')
|
||||
|
||||
if end_user is None and user_args is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user_args)
|
||||
|
||||
# check file
|
||||
if 'file' not in request.files:
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from models.model import EndUser, Message
|
||||
from services.message_service import MessageService
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from models.model import App, EndUser
|
||||
from services.message_service import MessageService
|
||||
|
||||
class MessageListApi(AppApiResource):
|
||||
|
||||
class MessageListApi(Resource):
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
}
|
||||
@@ -70,8 +68,9 @@ class MessageListApi(AppApiResource):
|
||||
'data': fields.List(fields.Nested(message_fields))
|
||||
}
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
@@ -79,12 +78,8 @@ class MessageListApi(AppApiResource):
|
||||
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
||||
parser.add_argument('first_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('user', type=str, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(app_model, end_user,
|
||||
args['conversation_id'], args['first_id'], args['limit'])
|
||||
@@ -94,18 +89,15 @@ class MessageListApi(AppApiResource):
|
||||
raise NotFound("First Message Not Exists.")
|
||||
|
||||
|
||||
class MessageFeedbackApi(AppApiResource):
|
||||
def post(self, app_model, end_user, message_id):
|
||||
class MessageFeedbackApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def post(self, app_model: App, end_user: EndUser, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
||||
parser.add_argument('user', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
|
||||
except services.errors.message.MessageNotExistsError:
|
||||
@@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class MessageSuggestedApi(AppApiResource):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
class MessageSuggestedApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def get(self, app_model: App, end_user: EndUser, message_id):
|
||||
message_id = str(message_id)
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
try:
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
).first()
|
||||
|
||||
if end_user is None and message.from_end_user_id is not None:
|
||||
user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.id == message.from_end_user_id,
|
||||
EndUser.type == 'service_api'
|
||||
).first()
|
||||
else:
|
||||
user = end_user
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
user=end_user,
|
||||
message_id=message_id,
|
||||
check_enabled=False
|
||||
)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from models.dataset import Dataset
|
||||
from flask import request
|
||||
from flask_restful import marshal, reqparse
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.dataset.error import DatasetNameDuplicateError
|
||||
@@ -6,9 +8,8 @@ from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from flask import request
|
||||
from flask_restful import marshal, reqparse
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
|
||||
|
||||
@@ -1,29 +1,34 @@
|
||||
import json
|
||||
|
||||
from flask import request
|
||||
from flask_restful import marshal, reqparse
|
||||
from sqlalchemy import desc
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.dataset.error import (ArchivedDocumentImmutableError, DocumentIndexingError,
|
||||
NoFileUploadedError, TooManyFilesError)
|
||||
from controllers.service_api.dataset.error import (
|
||||
ArchivedDocumentImmutableError,
|
||||
DocumentIndexingError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
)
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal, reqparse
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.dataset_service import DocumentService
|
||||
from services.file_service import FileService
|
||||
from sqlalchemy import desc
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
class DocumentAddByTextApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
|
||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
||||
@cloud_edition_billing_resource_check('documents', 'dataset')
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by text."""
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -149,6 +154,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
class DocumentAddByFileApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
@cloud_edition_billing_resource_check('vector_space', 'dataset')
|
||||
@cloud_edition_billing_resource_check('documents', 'dataset')
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by upload file."""
|
||||
args = {}
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
|
||||
@@ -6,11 +10,8 @@ from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from fields.segment_fields import segment_fields
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal, reqparse
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
class SegmentApi(DatasetApiResource):
|
||||
@@ -45,8 +46,8 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# validate args
|
||||
@@ -89,8 +90,8 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@@ -181,8 +182,8 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
@@ -199,8 +200,8 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
SegmentService.segment_create_args_validate(args['segments'], document)
|
||||
segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(args, segment, document, dataset)
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
|
||||
@@ -1,22 +1,40 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ext_database import db
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask_restful import Resource
|
||||
from libs.login import _get_user
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.model import ApiToken, App
|
||||
from services.feature_service import FeatureService
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import _get_user
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.model import ApiToken, App, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
def validate_app_token(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
|
||||
class WhereisUserArg(Enum):
|
||||
"""
|
||||
Enum for whereis_user_arg.
|
||||
"""
|
||||
QUERY = 'query'
|
||||
JSON = 'json'
|
||||
FORM = 'form'
|
||||
|
||||
|
||||
class FetchUserArg(BaseModel):
|
||||
fetch_from: WhereisUserArg
|
||||
required: bool = False
|
||||
|
||||
|
||||
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token('app')
|
||||
|
||||
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
|
||||
@@ -29,16 +47,35 @@ def validate_app_token(view=None):
|
||||
if not app_model.enable_api:
|
||||
raise NotFound()
|
||||
|
||||
return view(app_model, None, *args, **kwargs)
|
||||
return decorated
|
||||
kwargs['app_model'] = app_model
|
||||
|
||||
if view:
|
||||
if fetch_user_arg:
|
||||
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
|
||||
user_id = request.args.get('user')
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
|
||||
user_id = request.get_json().get('user')
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
|
||||
user_id = request.form.get('user')
|
||||
else:
|
||||
# use default-user
|
||||
user_id = None
|
||||
|
||||
if not user_id and fetch_user_arg.required:
|
||||
raise ValueError("Arg user must be provided.")
|
||||
|
||||
if user_id:
|
||||
user_id = str(user_id)
|
||||
|
||||
kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
# if view is None, it means that the decorator is used without parentheses
|
||||
# use the decorator as a function for method_decorators
|
||||
return decorator
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str,
|
||||
api_token_type: str,
|
||||
@@ -52,6 +89,7 @@ def cloud_edition_billing_resource_check(resource: str,
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
vector_space = features.vector_space
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
|
||||
if resource == 'members' and 0 < members.limit <= members.size:
|
||||
raise Unauthorized(error_msg)
|
||||
@@ -59,6 +97,8 @@ def cloud_edition_billing_resource_check(resource: str,
|
||||
raise Unauthorized(error_msg)
|
||||
elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
|
||||
raise Unauthorized(error_msg)
|
||||
elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
raise Unauthorized(error_msg)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@@ -76,7 +116,7 @@ def validate_dataset_token(view=None):
|
||||
.filter(Tenant.id == api_token.tenant_id) \
|
||||
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
|
||||
.filter(TenantAccountJoin.role.in_(['owner'])) \
|
||||
.one_or_none()
|
||||
.one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = Account.query.filter_by(id=ta.account_id).first()
|
||||
@@ -86,9 +126,9 @@ def validate_dataset_token(view=None):
|
||||
current_app.login_manager._update_request_context_with_user(account)
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user())
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account is not exist.")
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
else:
|
||||
raise Unauthorized("Tenant is not exist.")
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
return view(api_token.tenant_id, *args, **kwargs)
|
||||
return decorated
|
||||
|
||||
@@ -128,8 +168,33 @@ def validate_and_get_api_token(scope=None):
|
||||
return api_token
|
||||
|
||||
|
||||
class AppApiResource(Resource):
|
||||
method_decorators = [validate_app_token]
|
||||
def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
|
||||
"""
|
||||
Create or update session terminal based on user ID.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = 'DEFAULT-USER'
|
||||
|
||||
end_user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.app_id == app_model.id,
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.type == 'service_api'
|
||||
).first()
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='service_api',
|
||||
is_anonymous=True if user_id == 'DEFAULT-USER' else False,
|
||||
session_id=user_id
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
|
||||
class DatasetApiResource(Resource):
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
import json
|
||||
|
||||
|
||||
class AppParameterApi(WebApiResource):
|
||||
"""Resource for app variables."""
|
||||
@@ -77,7 +76,7 @@ class AppMeta(WebApiResource):
|
||||
# get all tools
|
||||
tools = agent_config.get('tools', [])
|
||||
url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
+ f"/console/api/workspaces/current/tool-provider/builtin/")
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/")
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
|
||||
@@ -1,21 +1,32 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.web import api
|
||||
from controllers.web.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError,
|
||||
NoAudioUploadedError, ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError, ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError, UnsupportedAudioTypeError)
|
||||
from controllers.web.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from flask import request
|
||||
from models.model import App, AppModelConfig
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError)
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
class AudioApi(WebApiResource):
|
||||
@@ -57,17 +68,23 @@ class AudioApi(WebApiResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logging.exception(f"internal server error: {str(e)}")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TextApi(WebApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.text_to_speech_dict['enabled']:
|
||||
raise AppUnavailableError()
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_tts(
|
||||
tenant_id=app_model.tenant_id,
|
||||
text=request.form['text'],
|
||||
end_user=end_user.external_user_id,
|
||||
voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=False
|
||||
)
|
||||
|
||||
@@ -94,7 +111,7 @@ class TextApi(WebApiResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logging.exception(f"internal server error: {str(e)}")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
|
||||
@@ -1,23 +1,31 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.web import api
|
||||
from controllers.web.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError,
|
||||
NotChatAppError, NotCompletionAppError, ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError, ProviderQuotaExceededError)
|
||||
from controllers.web.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import reqparse
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
|
||||
# define completion api for user
|
||||
@@ -146,8 +154,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_restful import marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.error import NotChatAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import uuid_value
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
class ConversationListApi(WebApiResource):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from flask import request
|
||||
from flask_restful import marshal_with
|
||||
|
||||
import services
|
||||
from controllers.web import api
|
||||
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.file_fields import file_fields
|
||||
from flask import request
|
||||
from flask_restful import marshal_with
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
|
||||
@@ -1,30 +1,37 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.web import api
|
||||
from controllers.web.error import (AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
CompletionRequestError, NotChatAppError, NotCompletionAppError,
|
||||
ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError)
|
||||
from controllers.web.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
CompletionRequestError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from fields.message_fields import agent_thought_fields
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
|
||||
class MessageListApi(WebApiResource):
|
||||
@@ -153,8 +160,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.web import api
|
||||
from extensions.ext_database import db
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from libs.passport import PassportService
|
||||
from models.model import App, EndUser, Site
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
|
||||
class PassportResource(Resource):
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
from models.model import Site
|
||||
from services.feature_service import FeatureService
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
class AppSiteApi(WebApiResource):
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from functools import wraps
|
||||
|
||||
from extensions.ext_database import db
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from models.model import App, EndUser, Site
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
|
||||
def validate_jwt_token(view=None):
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
from typing import List, cast
|
||||
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from langchain.schema import BaseMessage
|
||||
|
||||
|
||||
class CalcTokenMixin:
|
||||
|
||||
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
|
||||
"""
|
||||
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
|
||||
|
||||
:param model_config:
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
max_tokens = (model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)) or 0
|
||||
|
||||
if model_context_tokens is None:
|
||||
return 0
|
||||
|
||||
if max_tokens is None:
|
||||
max_tokens = 0
|
||||
|
||||
prompt_tokens = model_type_instance.get_num_tokens(
|
||||
model_config.model,
|
||||
model_config.credentials,
|
||||
messages
|
||||
)
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
|
||||
|
||||
return rest_tokens
|
||||
|
||||
|
||||
class ExceededLLMTokensLimitError(Exception):
|
||||
pass
|
||||
@@ -1,353 +0,0 @@
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
|
||||
|
||||
from core.agent.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
|
||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage, HumanMessage, SystemMessage,
|
||||
get_buffer_string)
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
|
||||
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_model_config: ModelConfigEntity = None
|
||||
model_config: ModelConfigEntity
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
prompt = cls.create_prompt(
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_config=model_config,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
agent_llm_callback=agent_llm_callback,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = 0
|
||||
for parameter_rule in self.model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
|
||||
or self.model_config.parameters.get(parameter_rule.use_template)) or 0
|
||||
|
||||
self.model_config.parameters['max_tokens'] = 40
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
try:
|
||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=self.model_config.provider_model_bundle,
|
||||
model=self.model_config.model,
|
||||
)
|
||||
|
||||
tools = []
|
||||
for function in self.functions:
|
||||
tool = PromptMessageTool(
|
||||
**function
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
model_parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
self.model_config.parameters['max_tokens'] = original_max_tokens
|
||||
|
||||
return True if result.message.tool_calls else False
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
|
||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
||||
|
||||
# summarize messages if rest_tokens < 0
|
||||
try:
|
||||
prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=self.model_config.provider_model_bundle,
|
||||
model=self.model_config.model,
|
||||
)
|
||||
|
||||
tools = []
|
||||
for function in self.functions:
|
||||
tool = PromptMessageTool(
|
||||
**function
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
|
||||
model_parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.message.content or "",
|
||||
additional_kwargs={
|
||||
'function_call': {
|
||||
'id': result.message.tool_calls[0].id,
|
||||
**result.message.tool_calls[0].function.dict()
|
||||
} if result.message.tool_calls else None
|
||||
}
|
||||
)
|
||||
agent_decision = _parse_ai_message(ai_message)
|
||||
|
||||
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
|
||||
tool_inputs = agent_decision.tool_input
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
|
||||
return agent_decision
|
||||
|
||||
@classmethod
|
||||
def get_system_message(cls):
|
||||
return SystemMessage(content="You are a helpful AI assistant.\n"
|
||||
"The current date or current time you know is wrong.\n"
|
||||
"Respond directly if appropriate.")
|
||||
|
||||
def return_stopped_response(
|
||||
self,
|
||||
early_stopping_method: str,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
**kwargs: Any,
|
||||
) -> AgentFinish:
|
||||
try:
|
||||
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
|
||||
except ValueError:
|
||||
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(
|
||||
self.model_config,
|
||||
messages,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
|
||||
if rest_tokens >= 0:
|
||||
return messages
|
||||
|
||||
system_message = None
|
||||
human_message = None
|
||||
should_summary_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, SystemMessage):
|
||||
system_message = message
|
||||
elif isinstance(message, HumanMessage):
|
||||
human_message = message
|
||||
else:
|
||||
should_summary_messages.append(message)
|
||||
|
||||
if len(should_summary_messages) > 2:
|
||||
ai_message = should_summary_messages[-2]
|
||||
function_message = should_summary_messages[-1]
|
||||
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
|
||||
self.moving_summary_index = len(should_summary_messages)
|
||||
else:
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
new_messages = [system_message, human_message]
|
||||
|
||||
if self.moving_summary_index == 0:
|
||||
should_summary_messages.insert(0, human_message)
|
||||
|
||||
self.moving_summary_buffer = self.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
|
||||
new_messages.append(AIMessage(content=self.moving_summary_buffer))
|
||||
new_messages.append(ai_message)
|
||||
new_messages.append(function_message)
|
||||
|
||||
return new_messages
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
human_prefix="Human",
|
||||
ai_prefix="AI",
|
||||
)
|
||||
|
||||
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
if model_config.provider == 'azure_openai':
|
||||
model = model_config.model
|
||||
model = model.replace("gpt-35", "gpt-3.5")
|
||||
else:
|
||||
model = model_config.credentials.get("base_model_name")
|
||||
|
||||
tiktoken_ = _import_tiktoken()
|
||||
try:
|
||||
encoding = tiktoken_.encoding_for_model(model)
|
||||
except KeyError:
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken_.get_encoding(model)
|
||||
|
||||
if model.startswith("gpt-3.5-turbo"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-4"):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
for m in messages:
|
||||
message = _convert_message_to_dict(m)
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if key == "function_call":
|
||||
for f_key, f_value in value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if kwargs.get('functions'):
|
||||
for function in kwargs.get('functions'):
|
||||
num_tokens += len(encoding.encode('name'))
|
||||
num_tokens += len(encoding.encode(function.get("name")))
|
||||
num_tokens += len(encoding.encode('description'))
|
||||
num_tokens += len(encoding.encode(function.get("description")))
|
||||
parameters = function.get("parameters")
|
||||
num_tokens += len(encoding.encode('parameters'))
|
||||
if 'title' in parameters:
|
||||
num_tokens += len(encoding.encode('title'))
|
||||
num_tokens += len(encoding.encode(parameters.get("title")))
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode(parameters.get("type")))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += len(encoding.encode('properties'))
|
||||
for key, value in parameters.get('properties').items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
if field_key == 'enum':
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(enum_field))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
num_tokens += len(encoding.encode(str(field_value)))
|
||||
if 'required' in parameters:
|
||||
num_tokens += len(encoding.encode('required'))
|
||||
for required_field in parameters['required']:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(required_field))
|
||||
|
||||
return num_tokens
|
||||
@@ -1,297 +0,0 @@
|
||||
import re
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
|
||||
|
||||
from core.agent.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
from langchain import BasePromptTemplate, PromptTemplate
|
||||
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
|
||||
from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage, HumanMessage, OutputParserException,
|
||||
get_buffer_string)
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||
Valid "action" values: "Final Answer" or {tool_names}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{{{{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $INPUT
|
||||
}}}}
|
||||
```
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
Action:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{{{{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}}}}
|
||||
```"""
|
||||
|
||||
|
||||
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_model_config: ModelConfigEntity = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
Using the ReACT mode to determine whether an agent is needed is costly,
|
||||
so it's better to just use an Agent for reasoning, which is cheaper.
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
return True
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observatons
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
|
||||
|
||||
messages = []
|
||||
if prompts:
|
||||
messages = prompts[0].to_messages()
|
||||
|
||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
||||
|
||||
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
|
||||
if rest_tokens < 0:
|
||||
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
|
||||
|
||||
try:
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
agent_decision = self.output_parser.parse(full_output)
|
||||
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
|
||||
tool_inputs = agent_decision.tool_input
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
return agent_decision
|
||||
except OutputParserException:
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||
if len(intermediate_steps) >= 2 and self.summary_model_config:
|
||||
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||
should_summary_messages = [AIMessage(content=observation)
|
||||
for _, observation in should_summary_intermediate_steps]
|
||||
if self.moving_summary_index == 0:
|
||||
should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
|
||||
|
||||
self.moving_summary_index = len(intermediate_steps)
|
||||
else:
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
if self.moving_summary_buffer and 'chat_history' in kwargs:
|
||||
kwargs["chat_history"].pop()
|
||||
|
||||
self.moving_summary_buffer = self.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
|
||||
if 'chat_history' in kwargs:
|
||||
kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
|
||||
|
||||
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
human_prefix="Human",
|
||||
ai_prefix="AI",
|
||||
)
|
||||
|
||||
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
) -> BasePromptTemplate:
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
|
||||
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
_memory_prompts = memory_prompts or []
|
||||
messages = [
|
||||
SystemMessagePromptTemplate.from_template(template),
|
||||
*_memory_prompts,
|
||||
HumanMessagePromptTemplate.from_template(human_message_template),
|
||||
]
|
||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
@classmethod
|
||||
def create_completion_prompt(
|
||||
cls,
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
Args:
|
||||
tools: List of tools the agent will have access to, used to format the
|
||||
prompt.
|
||||
prefix: String to put before the list of tools.
|
||||
input_variables: List of input variables the final prompt will expect.
|
||||
|
||||
Returns:
|
||||
A PromptTemplate with the template assembled from the pieces here.
|
||||
"""
|
||||
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
Question: {input}
|
||||
Thought: {agent_scratchpad}
|
||||
"""
|
||||
|
||||
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
agent_scratchpad = ""
|
||||
for action, observation in intermediate_steps:
|
||||
agent_scratchpad += action.log
|
||||
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
|
||||
|
||||
if not isinstance(agent_scratchpad, str):
|
||||
raise ValueError("agent_scratchpad should be of type string.")
|
||||
if agent_scratchpad:
|
||||
llm_chain = cast(LLMChain, self.llm_chain)
|
||||
if llm_chain.model_config.mode == "chat":
|
||||
return (
|
||||
f"This was your previous work "
|
||||
f"(but I haven't seen any of it! I only see what "
|
||||
f"you return as final answer):\n{agent_scratchpad}"
|
||||
)
|
||||
else:
|
||||
return agent_scratchpad
|
||||
else:
|
||||
return agent_scratchpad
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
if model_config.mode == "chat":
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
else:
|
||||
prompt = cls.create_completion_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
model_config=model_config,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
agent_llm_callback=agent_llm_callback,
|
||||
parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
_output_parser = output_parser
|
||||
return cls(
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
output_parser=_output_parser,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -1,9 +1,20 @@
|
||||
import time
|
||||
from typing import Generator, List, Optional, Tuple, Union, cast
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
||||
from core.entities.application_entities import AppOrchestrationConfigEntity, ModelConfigEntity, \
|
||||
PromptTemplateEntity, ExternalDataVariableEntity, ApplicationGenerateEntity, InvokeFrom
|
||||
from core.entities.application_entities import (
|
||||
ApplicationGenerateEntity,
|
||||
AppOrchestrationConfigEntity,
|
||||
ExternalDataVariableEntity,
|
||||
InvokeFrom,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.features.annotation_reply import AnnotationReplyFeature
|
||||
from core.features.external_data_fetch import ExternalDataFetchFeature
|
||||
from core.features.hosting_moderation import HostingModerationFeature
|
||||
from core.features.moderation import ModerationFeature
|
||||
from core.file.file_obj import FileObj
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
@@ -11,12 +22,9 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.features.hosting_moderation import HostingModerationFeature
|
||||
from core.features.moderation import ModerationFeature
|
||||
from core.features.external_data_fetch import ExternalDataFetchFeature
|
||||
from core.features.annotation_reply import AnnotationReplyFeature
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from models.model import App, MessageAnnotation, Message
|
||||
from models.model import App, Message, MessageAnnotation
|
||||
|
||||
|
||||
class AppRunner:
|
||||
def get_pre_calculate_rest_tokens(self, app_record: App,
|
||||
@@ -76,8 +84,8 @@ class AppRunner:
|
||||
|
||||
return rest_tokens
|
||||
|
||||
def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
|
||||
prompt_messages: List[PromptMessage]):
|
||||
def recalc_llm_max_tokens(self, model_config: ModelConfigEntity,
|
||||
prompt_messages: list[PromptMessage]):
|
||||
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
@@ -119,7 +127,7 @@ class AppRunner:
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None) \
|
||||
-> Tuple[List[PromptMessage], Optional[List[str]]]:
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
:param context:
|
||||
@@ -288,7 +296,7 @@ class AppRunner:
|
||||
tenant_id: str,
|
||||
app_orchestration_config_entity: AppOrchestrationConfigEntity,
|
||||
inputs: dict,
|
||||
query: str) -> Tuple[bool, dict, str]:
|
||||
query: str) -> tuple[bool, dict, str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
:param app_id: app id
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from core.app_runner.app_runner import AppRunner
|
||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
||||
from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity
|
||||
from core.features.assistant_cot_runner import AssistantCotApplicationRunner
|
||||
from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner
|
||||
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
|
||||
AgentEntity
|
||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.moderation.base import ModerationException
|
||||
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message, App, MessageChain, MessageAgentThought
|
||||
from models.model import App, Conversation, Message, MessageAgentThought
|
||||
from models.tools import ToolConversationVariables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -38,7 +37,7 @@ class AssistantApplicationRunner(AppRunner):
|
||||
"""
|
||||
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError(f"App not found")
|
||||
raise ValueError("App not found")
|
||||
|
||||
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
|
||||
|
||||
@@ -169,15 +168,10 @@ class AssistantApplicationRunner(AppRunner):
|
||||
# load tool variables
|
||||
tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tanent_id=application_generate_entity.tenant_id)
|
||||
tenant_id=application_generate_entity.tenant_id)
|
||||
|
||||
# convert db variables to tool variables
|
||||
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
|
||||
|
||||
message_chain = self._init_message_chain(
|
||||
message=message,
|
||||
query=query
|
||||
)
|
||||
|
||||
# init model instance
|
||||
model_instance = ModelInstance(
|
||||
@@ -194,6 +188,17 @@ class AssistantApplicationRunner(AppRunner):
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
# change function call strategy based on LLM model
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
|
||||
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
|
||||
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
|
||||
db.session.refresh(conversation)
|
||||
db.session.refresh(message)
|
||||
db.session.close()
|
||||
|
||||
# start agent runner
|
||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
assistant_cot_runner = AssistantCotApplicationRunner(
|
||||
@@ -209,12 +214,13 @@ class AssistantApplicationRunner(AppRunner):
|
||||
prompt_messages=prompt_message,
|
||||
variables_pool=tool_variables,
|
||||
db_variables=tool_conversation_variables,
|
||||
model_instance=model_instance
|
||||
)
|
||||
invoke_result = assistant_cot_runner.run(
|
||||
model_instance=model_instance,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
)
|
||||
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
assistant_fc_runner = AssistantFunctionCallApplicationRunner(
|
||||
@@ -229,10 +235,10 @@ class AssistantApplicationRunner(AppRunner):
|
||||
memory=memory,
|
||||
prompt_messages=prompt_message,
|
||||
variables_pool=tool_variables,
|
||||
db_variables=tool_conversation_variables
|
||||
db_variables=tool_conversation_variables,
|
||||
model_instance=model_instance
|
||||
)
|
||||
invoke_result = assistant_fc_runner.run(
|
||||
model_instance=model_instance,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
query=query,
|
||||
@@ -246,13 +252,13 @@ class AssistantApplicationRunner(AppRunner):
|
||||
agent=True
|
||||
)
|
||||
|
||||
def _load_tool_variables(self, conversation_id: str, user_id: str, tanent_id: str) -> ToolConversationVariables:
|
||||
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
|
||||
"""
|
||||
load tool variables from database
|
||||
"""
|
||||
tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
|
||||
ToolConversationVariables.conversation_id == conversation_id,
|
||||
ToolConversationVariables.tenant_id == tanent_id
|
||||
ToolConversationVariables.tenant_id == tenant_id
|
||||
).first()
|
||||
|
||||
if tool_variables:
|
||||
@@ -263,7 +269,7 @@ class AssistantApplicationRunner(AppRunner):
|
||||
tool_variables = ToolConversationVariables(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tanent_id,
|
||||
tenant_id=tenant_id,
|
||||
variables_str='[]',
|
||||
)
|
||||
db.session.add(tool_variables)
|
||||
@@ -282,38 +288,6 @@ class AssistantApplicationRunner(AppRunner):
|
||||
'pool': db_variables.variables
|
||||
})
|
||||
|
||||
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
|
||||
"""
|
||||
Init MessageChain
|
||||
:param message: message
|
||||
:param query: query
|
||||
:return:
|
||||
"""
|
||||
message_chain = MessageChain(
|
||||
message_id=message.id,
|
||||
type="AgentExecutor",
|
||||
input=json.dumps({
|
||||
"input": query
|
||||
})
|
||||
)
|
||||
|
||||
db.session.add(message_chain)
|
||||
db.session.commit()
|
||||
|
||||
return message_chain
|
||||
|
||||
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
|
||||
"""
|
||||
Save MessageChain
|
||||
:param message_chain: message chain
|
||||
:param output_text: output text
|
||||
:return:
|
||||
"""
|
||||
message_chain.output = json.dumps({
|
||||
"output": output_text
|
||||
})
|
||||
db.session.commit()
|
||||
|
||||
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
|
||||
message: Message) -> LLMUsage:
|
||||
"""
|
||||
|
||||
@@ -4,9 +4,8 @@ from typing import Optional
|
||||
from core.app_runner.app_runner import AppRunner
|
||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.entities.application_entities import (ApplicationGenerateEntity, DatasetEntity,
|
||||
InvokeFrom, ModelConfigEntity)
|
||||
from core.features.dataset_retrieval import DatasetRetrievalFeature
|
||||
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
|
||||
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationException
|
||||
@@ -36,7 +35,7 @@ class BasicApplicationRunner(AppRunner):
|
||||
"""
|
||||
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError(f"App not found")
|
||||
raise ValueError("App not found")
|
||||
|
||||
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
|
||||
|
||||
@@ -182,7 +181,7 @@ class BasicApplicationRunner(AppRunner):
|
||||
return
|
||||
|
||||
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
self.recale_llm_max_tokens(
|
||||
self.recalc_llm_max_tokens(
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
@@ -193,6 +192,8 @@ class BasicApplicationRunner(AppRunner):
|
||||
model=app_orchestration_config.model_config.model
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_orchestration_config.model_config.parameters,
|
||||
|
||||
@@ -1,30 +1,45 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Generator, Optional, Union, cast
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
|
||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
||||
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
|
||||
from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentThoughtEvent, QueueErrorEvent,
|
||||
QueueMessageEndEvent, QueueMessageEvent, QueueMessageReplaceEvent,
|
||||
QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent,
|
||||
QueueMessageFileEvent, QueueAgentMessageEvent)
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.entities.queue_entities import (
|
||||
AnnotationReplyEvent,
|
||||
QueueAgentMessageEvent,
|
||||
QueueAgentThoughtEvent,
|
||||
QueueErrorEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueMessageEvent,
|
||||
QueueMessageFileEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
|
||||
PromptMessage, PromptMessageContentType, PromptMessageRole,
|
||||
TextPromptMessageContent)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||
from pydantic import BaseModel
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -74,6 +89,10 @@ class GenerateTaskPipeline:
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
"""
|
||||
db.session.refresh(self._conversation)
|
||||
db.session.refresh(self._message)
|
||||
db.session.close()
|
||||
|
||||
if stream:
|
||||
return self._process_stream_response()
|
||||
else:
|
||||
@@ -104,7 +123,7 @@ class GenerateTaskPipeline:
|
||||
}
|
||||
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
self._task_state.llm_result = event.llm_result
|
||||
else:
|
||||
@@ -160,7 +179,7 @@ class GenerateTaskPipeline:
|
||||
'id': self._message.id,
|
||||
'message_id': self._message.id,
|
||||
'mode': self._conversation.mode,
|
||||
'answer': event.llm_result.message.content,
|
||||
'answer': self._task_state.llm_result.message.content,
|
||||
'metadata': {},
|
||||
'created_at': int(self._message.created_at.timestamp())
|
||||
}
|
||||
@@ -187,7 +206,7 @@ class GenerateTaskPipeline:
|
||||
data = self._error_to_stream_response_data(self._handle_error(event))
|
||||
yield self._yield_response(data)
|
||||
break
|
||||
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
self._task_state.llm_result = event.llm_result
|
||||
else:
|
||||
@@ -288,6 +307,7 @@ class GenerateTaskPipeline:
|
||||
.first()
|
||||
)
|
||||
db.session.refresh(agent_thought)
|
||||
db.session.close()
|
||||
|
||||
if agent_thought:
|
||||
response = {
|
||||
@@ -315,6 +335,8 @@ class GenerateTaskPipeline:
|
||||
.filter(MessageFile.id == event.message_file_id)
|
||||
.first()
|
||||
)
|
||||
db.session.close()
|
||||
|
||||
# get extension
|
||||
if '.' in message_file.url:
|
||||
extension = f'.{message_file.url.split(".")[-1]}'
|
||||
@@ -339,7 +361,7 @@ class GenerateTaskPipeline:
|
||||
|
||||
yield self._yield_response(response)
|
||||
|
||||
elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
|
||||
elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
|
||||
chunk = event.chunk
|
||||
delta_text = chunk.delta.message.content
|
||||
if delta_text is None:
|
||||
@@ -398,6 +420,7 @@ class GenerateTaskPipeline:
|
||||
usage = llm_result.usage
|
||||
|
||||
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
|
||||
|
||||
self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
|
||||
self._message.message_tokens = usage.prompt_tokens
|
||||
@@ -463,44 +486,34 @@ class GenerateTaskPipeline:
|
||||
:param e: exception
|
||||
:return:
|
||||
"""
|
||||
if isinstance(e, ValueError):
|
||||
data = {
|
||||
'code': 'invalid_param',
|
||||
'message': str(e),
|
||||
'status': 400
|
||||
}
|
||||
elif isinstance(e, ProviderTokenNotInitError):
|
||||
data = {
|
||||
'code': 'provider_not_initialize',
|
||||
'message': e.description,
|
||||
'status': 400
|
||||
}
|
||||
elif isinstance(e, QuotaExceededError):
|
||||
data = {
|
||||
error_responses = {
|
||||
ValueError: {'code': 'invalid_param', 'status': 400},
|
||||
ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
|
||||
QuotaExceededError: {
|
||||
'code': 'provider_quota_exceeded',
|
||||
'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials.",
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials.",
|
||||
'status': 400
|
||||
}
|
||||
elif isinstance(e, ModelCurrentlyNotSupportError):
|
||||
data = {
|
||||
'code': 'model_currently_not_support',
|
||||
'message': e.description,
|
||||
'status': 400
|
||||
}
|
||||
elif isinstance(e, InvokeError):
|
||||
data = {
|
||||
'code': 'completion_request_error',
|
||||
'message': e.description,
|
||||
'status': 400
|
||||
}
|
||||
},
|
||||
ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
|
||||
InvokeError: {'code': 'completion_request_error', 'status': 400}
|
||||
}
|
||||
|
||||
# Determine the response based on the type of exception
|
||||
data = None
|
||||
for k, v in error_responses.items():
|
||||
if isinstance(e, k):
|
||||
data = v
|
||||
|
||||
if data:
|
||||
data.setdefault('message', getattr(e, 'description', str(e)))
|
||||
else:
|
||||
logging.error(e)
|
||||
data = {
|
||||
'code': 'internal_server_error',
|
||||
'code': 'internal_server_error',
|
||||
'message': 'Internal Server Error, please contact support.',
|
||||
'status': 500
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
'event': 'error',
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.application_queue_manager import PublishFrom
|
||||
from core.moderation.base import ModerationAction, ModerationOutputsResult
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModerationRule(BaseModel):
|
||||
type: str
|
||||
config: Dict[str, Any]
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class OutputModerationHandler(BaseModel):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user