Compare commits


125 Commits
0.5.5 ... 0.5.8

Author SHA1 Message Date
takatost
534802b761 bump version to 0.5.8 (#2685) 2024-03-05 01:37:53 +08:00
takatost
5c258e212c feat: add Anthropic claude-3 models support (#2684) 2024-03-05 01:37:42 +08:00
Charlie.Wei
6a6133c102 Fix voice selection (#2664)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-03-04 17:50:06 +08:00
Joel
3c1825187a fix: auto generate prompt result not show (#2678) 2024-03-04 17:36:11 +08:00
Joshua
8523b34be7 add jina-reranker-v1-base-en (#2676) 2024-03-04 17:31:01 +08:00
Bowen Liang
65cfd4360a fix: typo in wecom tool (#2674) 2024-03-04 17:25:42 +08:00
Joel
bbf5f42c87 fix: CE edition limits upload file nums (#2677) 2024-03-04 17:25:31 +08:00
Jyong
3631e53ff0 Feat/add annotation migrate (#2675)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-04 17:22:06 +08:00
waltcow
f322d9bddb Fix vdb merge error (#2650) 2024-03-04 16:35:50 +08:00
Yeuoly
05ce7b9d5e fix: deep copy customColletion (#2673) 2024-03-04 15:20:20 +08:00
Yeuoly
72ddedfc5c fix: setup default filters while add credentials (#2669) 2024-03-04 14:17:00 +08:00
Yeuoly
36686d7425 fix: test custom tool already exists without decrypting credentials (#2668) 2024-03-04 14:16:47 +08:00
cola
34387ec0f1 fix typo recale to recalc (#2670) 2024-03-04 14:15:53 +08:00
Chenhe Gu
83a6b0c626 Doc/update license (#2666) 2024-03-04 14:10:39 +08:00
takatost
76da66fb7e fix: fix import from explore apps err when OpenAI not inited (#2671) 2024-03-04 14:06:54 +08:00
nan jiang
607f9eda35 Fix/app runner typo (#2661) 2024-03-04 13:32:17 +08:00
Bowen Liang
f25cec265d feat: add Wecom(企业微信) tool for sending message to chat group bot via webhook (#2638) 2024-03-04 10:27:20 +08:00
Garfield Dai
8e66b96221 Feat: Add documents limitation (#2662) 2024-03-03 12:45:06 +08:00
crazywoola
b5c1bb346c Add PubMed to tools (#2652) 2024-03-03 12:44:13 +08:00
Yeuoly
e94b323e6c fix: use English as the default i18n language (#2663) 2024-03-03 12:35:28 +08:00
nan jiang
bc65ee10c0 bugfix: model str maybe empty (#2660) 2024-03-03 11:43:38 +08:00
Rozstone
2001483659 fix: default to allcategories when search params is not from recommended (#2653) 2024-03-02 17:11:25 +08:00
crazywoola
444aba55dd Feat/jpn support (#2651) 2024-03-02 13:47:51 +08:00
Joel
3f640b1037 fix: click tool item in app debug page would show detail (#2644) 2024-03-01 18:47:12 +08:00
Yeuoly
b07084711c fix: missing description (#2643) 2024-03-01 18:19:04 +08:00
Joel
fa8ab2134f feat: displaying the tool description when clicking on a custom tool (#2642) 2024-03-01 17:58:38 +08:00
takatost
1a677da792 fix: custom tool max tool (#2641) 2024-03-01 16:43:47 +08:00
taokuizu
b6d61a818e fix: Replace path.join with urljoin. (#2631) 2024-03-01 13:07:15 +08:00
Bowen Liang
8495ffaa45 fix: typo in gaode tool (#2636) 2024-03-01 10:12:48 +08:00
Yash Parmar
dbd1d79770 FEAT: Add arxiv tool for searching scientific papers and articles fro… (#2632) 2024-02-29 19:46:10 +08:00
takatost
1910178199 fix: default mail type invalid in .env.example (#2628) 2024-02-29 17:29:48 +08:00
Bowen Liang
839a6a2c8a add logs for vdb-migrate command (#2626) 2024-02-29 16:24:51 +08:00
Yeuoly
a769edbc89 Fix/custom tool any of (#2625) 2024-02-29 14:39:05 +08:00
Yeuoly
57ffecd0e5 fix: remove unnecessary credentials of custom tool (#2621) 2024-02-29 12:58:12 +08:00
Bowen Liang
801d135390 generalize the generation of new collection name by dataset id (#2620) 2024-02-29 12:47:10 +08:00
Bowen Liang
0428f44113 chore: bump superlinter action from v5 to v6 (#2325) 2024-02-29 12:45:06 +08:00
zxhlyh
7beff3fd5a fix: model parameter load presets config (#2622) 2024-02-29 12:43:46 +08:00
takatost
88a095e40e fix: wrong default model parameters when creating app (#2623) 2024-02-29 12:43:07 +08:00
takatost
dd961985f0 refactor: remove unused codes, move core/agent module into dataset retrieval feature (#2614) 2024-02-28 23:32:47 +08:00
Yeuoly
d44b05a9e5 feat: support auth type like basic bearer and custom (#2613) 2024-02-28 23:19:08 +08:00
takatost
5bd3b02be6 version to 0.5.7 (#2610) 2024-02-28 18:07:13 +08:00
crazywoola
3cf5c1853d Fix: default button behavior (#2609) 2024-02-28 17:34:20 +08:00
takatost
a4d86496e1 fix: notion extractor raise 'NoneType' object has no attribute 'curre… (#2608) 2024-02-28 17:08:27 +08:00
takatost
90bdc85f8c fix: AppParameterApi.get() got an unexpected keyword argument 'end_user' (#2607) 2024-02-28 16:46:50 +08:00
takatost
0828873b52 fix: missing default user for APP service api (#2606) 2024-02-28 16:09:56 +08:00
crazywoola
816b707a16 Fix: explore apps is not shown (#2604) 2024-02-28 15:43:42 +08:00
crazywoola
c9257ab4bf Fix/2559 upload powered by brand image not showing up (#2602) 2024-02-28 15:17:49 +08:00
cola
69ce3b3d33 fix props.appDetail.api_base_url /v1 repeat error (#2601) 2024-02-28 15:13:38 +08:00
crazywoola
c4caa7c401 doc: props.appDetail.api_base_url (#2597) 2024-02-28 13:40:57 +08:00
Joshua
dc93a292c3 Feat/provider mistralai (#2598) 2024-02-28 13:39:55 +08:00
takatost
174ee1b646 fix: parameter user exceeded max length when invoking moonshot llm (#2596) 2024-02-28 12:23:34 +08:00
Joshua
9b1c4f47fb feat:add mistral ai (#2594) 2024-02-28 12:22:57 +08:00
crazywoola
582ba45c00 Fix 500 error when creating from the template and the provider is None (#2591) 2024-02-28 11:27:17 +08:00
Rozstone
f1cbd55007 enhancement: skip fetching to improve user experience when switching … (#2580) 2024-02-27 19:16:22 +08:00
Yeuoly
3a34370422 fix: convert tool messages into user messages in react mode and fill … (#2584) 2024-02-27 19:15:07 +08:00
Bowen Liang
29ab244de6 fix: correct the parent class of CacheEmbedding (#2578) 2024-02-27 18:05:48 +08:00
Jyong
920b2c2b40 Fix/hit test tsne issue (#2581)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-27 17:30:52 +08:00
Yeuoly
ac96d192a6 fix: parameter type handling in API tool and parser (#2574) 2024-02-27 15:59:11 +08:00
Rozstone
07fbeb6cf0 enhancement: improve client-side code (#2568) 2024-02-27 15:58:57 +08:00
Jyong
fc64cdee64 fix mivlus delete by ids error (#2573)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-27 12:23:13 +08:00
zxhlyh
0c0e96c55f fix: notion binding (#2572) 2024-02-27 11:59:54 +08:00
Jyong
5b953c1ef2 Fix some RAG bugs (#2570)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-27 11:39:05 +08:00
Bowen Liang
562ca45e07 fix weaviate delete_by_ids (#2565) 2024-02-27 11:14:35 +08:00
crazywoola
6bbd53512e Add Dify Meetup Event on Mar 9 (#2566) 2024-02-27 10:40:26 +08:00
Bowen Liang
e352a8ed1b chore: remove redundant casting flask app config into dict (#2564) 2024-02-27 09:39:26 +08:00
Bowen Liang
e55225e2bc fix typo in error message of supported keyword store (#2560) 2024-02-27 00:47:36 +08:00
Yeuoly
3e63abd335 Feat/json mode (#2563) 2024-02-26 23:34:40 +08:00
Jyong
0620fa3094 Feat/vdb migrate command (#2562)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-26 19:47:29 +08:00
Rozstone
d93288f711 Feat/use searchparams as state (#2554)
Co-authored-by: crazywoola <427733928@qq.com>
2024-02-26 12:52:59 +08:00
Rozstone
ca69af7b97 feat: change max_question_num to 5 (#2520)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-02-24 09:28:27 +08:00
takatost
952e13fef8 Update README_CN.md (#2550) 2024-02-23 17:38:03 +08:00
Jyong
4be3087642 Fix/new RAG bugs (#2547)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-23 16:54:15 +08:00
Garfield Dai
49da8a23a8 feat: openai llm get trial or paid models from config. (#2546) 2024-02-23 16:48:58 +08:00
Garfield Dai
3ad943a9eb Feat/openai llm trial paid config (#2545) 2024-02-23 16:12:43 +08:00
zxhlyh
3082093293 fix: webapp name (#2543) 2024-02-23 14:54:03 +08:00
Jyong
b03bbab5ad fix dev/reformat (#2542)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-23 14:53:24 +08:00
crazywoola
9574730050 Feat/i18n restructure (#2529) 2024-02-23 14:31:06 +08:00
Jyong
91ea6fe4ee Fix/langchain document schema (#2539)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-23 14:16:44 +08:00
Joel
769be13189 chore: add api key and value placeholder (#2538) 2024-02-23 13:55:43 +08:00
Bowen Liang
e42175241e fix: tolerate exceptions in cleaning up index when vector db service unavailable (#2533) 2024-02-23 12:30:39 +08:00
Yeuoly
12257b438b Fix/tool default value (#2536) 2024-02-23 12:02:29 +08:00
Bowen Liang
9ecc736c30 fix: update current tenant id of account when switching tenant (#2530) 2024-02-23 10:51:19 +08:00
Jyong
6c4e6bf1d6 Feat/dify rag (#2528)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-22 23:31:57 +08:00
Jyong
97fe817186 Fix/upload limit (#2521)
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2024-02-22 17:16:22 +08:00
Charlie.Wei
52b12ed7eb Voice audition (#2504)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-02-22 16:06:17 +08:00
Yeuoly
d8ab4474b4 fix: bing search response filter (#2519) 2024-02-22 13:06:55 +08:00
crazywoola
1ecbd95adf Fix #2512 (#2515) 2024-02-22 09:22:57 +08:00
crazywoola
cad6e6624f fix: config not exists (#2513) 2024-02-21 19:27:38 +08:00
crazywoola
3505cbe05c update issue template (#2507) 2024-02-21 14:08:11 +08:00
Joel
e15359e589 fix: api doc example error (#2505) 2024-02-21 12:03:48 +08:00
Yeuoly
edb86f5f5a Feat/stream react (#2498) 2024-02-21 10:45:59 +08:00
Yash_1124
adf2651d1f FEAT: Add DuckDuckGo Search Tool for Enhanced Privacy-Focused Search Functionality (#2499) 2024-02-21 10:42:34 +08:00
Chenhe Gu
5031d64e28 Chore/delete chunk decode error alert (#2500) 2024-02-21 03:17:33 +08:00
Yeuoly
ae3ad59b16 Refactor agent history organization and initialization of agent scrat… (#2495) 2024-02-20 19:03:43 +08:00
Yeuoly
e6cd7b0467 feat: increase max tools (#2497) 2024-02-20 19:03:10 +08:00
crazywoola
97e9f52331 doc: typo in chat (#2492) 2024-02-20 16:08:01 +08:00
Yeuoly
25957d917a Add default values for optional parameters in API tool and parser (#2491) 2024-02-20 16:07:43 +08:00
Jyong
20b932da97 del doc support (#2494)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-20 16:05:09 +08:00
zxhlyh
207080babc fix: audio to text (#2493) 2024-02-20 15:16:46 +08:00
Yeuoly
48bacd01cc fix: incorrect tool name (#2489) 2024-02-20 14:50:57 +08:00
zxhlyh
297d0f1f30 fix: code-based extension (#2490) 2024-02-20 14:49:00 +08:00
zxhlyh
eedbe1b770 fix: chat restart (#2488) 2024-02-20 11:24:27 +08:00
kukuze
5ff6b1da07 Windows local deployment switch "tool“ interface failed (#2483) 2024-02-19 20:03:20 +08:00
takatost
8b49e0ee2a bump version to 0.5.6 (#2482) 2024-02-19 17:13:55 +08:00
crazywoola
e031ec9359 remove: parameters in seeds (#2481) 2024-02-19 17:00:46 +08:00
takatost
1bd1cd6938 fix: event handlers not registered globally (#2479) 2024-02-19 16:04:52 +08:00
Yash_1124
81c5a21b8d FEAT: add image styling in markdown (#2441)
Co-authored-by: crazywoola <427733928@qq.com>
2024-02-19 15:07:45 +08:00
Koen Farell
61e4bbabaf feat: added Ukrainian language support (#2473) 2024-02-19 13:11:23 +08:00
takatost
4cf475680d fix: credential verification of baichuan did not throw all errors (#2475) 2024-02-19 11:52:52 +08:00
Yeuoly
ca4aa340f6 fix: Add model_uid validation for model_uid in Xinference models (#2468) 2024-02-19 10:43:25 +08:00
Joel
767d8a4b05 fix: hybrid search may pass rerank enable false (#2467) 2024-02-18 17:52:05 +08:00
TseIan
0b8dcaba8f Chore: Add type files and unit test ci for Node.js SDK (#2268)
Co-authored-by: xieweicheng <xieweicheng@bytedance.com>
2024-02-18 15:54:14 +08:00
wjryours
af6a318aae fix: windows load provider file error (#2463) 2024-02-18 15:48:25 +08:00
Charlie.Wei
c6e2900be7 Display selected tts voice name (#2459)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-02-18 15:39:25 +08:00
crazywoola
963d9b6032 Feature/display selected info for tts (#2454) 2024-02-16 20:05:14 +08:00
johnpccd
b2ee738bb1 Ignore SSE comments to support openrouter streaming (#2432) 2024-02-16 10:00:10 +08:00
Charlie.Wei
c8ca3ff404 Tts add voice choose (#2453)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-02-16 01:10:11 +08:00
Charlie.Wei
5d8fa2c7af Tts add voice choose (#2452)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-02-16 00:15:22 +08:00
takatost
58df5e5376 fix: tts voice language to zh-Hans instead of zh-CN (#2450) 2024-02-16 00:05:29 +08:00
takatost
348ad1a624 Update pull_request_template.md (#2451) 2024-02-16 00:05:18 +08:00
takatost
73e17d5aa8 Create pull_request_template.md (#2449) 2024-02-15 23:35:59 +08:00
Charlie.Wei
300d9892a5 tts add voice choose (#2391)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-02-15 22:41:18 +08:00
Yeuoly
e47b5b43b8 fix: baichuan frequency_penalty (#2446) 2024-02-14 20:11:41 +08:00
takatost
21c9d9e200 feat: add introduction field in log detail response of chat app (#2445) 2024-02-14 12:38:13 +08:00
Igor Voloc
4f6916c4d8 Update SMTP environment variable name in docker-compose (#2444) 2024-02-14 12:29:27 +08:00
537 changed files with 14847 additions and 10136 deletions

View File

@@ -10,7 +10,9 @@ body:
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "Pleas do not modify this template :) and fill in all the required fields."
required: true
- type: input

View File

@@ -10,7 +10,9 @@ body:
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "Pleas do not modify this template :) and fill in all the required fields."
required: true
- type: textarea
attributes:

View File

@@ -10,7 +10,9 @@ body:
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "Pleas do not modify this template :) and fill in all the required fields."
required: true
- type: textarea
attributes:

View File

@@ -10,7 +10,9 @@ body:
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "Pleas do not modify this template :) and fill in all the required fields."
required: true
- type: textarea
attributes:

View File

@@ -10,7 +10,9 @@ body:
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "Pleas do not modify this template :) and fill in all the required fields."
required: true
- type: input
attributes:

.github/pull_request_template.md (new file, 30 lines)
View File

@@ -0,0 +1,30 @@
# Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes # (issue)
## Type of Change
Please delete options that are not relevant.
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
# How Has This Been Tested?
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration
- [ ] TODO
# Suggested Checklist:
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] My changes generate no new warnings
- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
- [ ] `optional` I have made corresponding changes to the documentation
- [ ] `optional` I have added tests that prove my fix is effective or that my feature works
- [ ] `optional` New and existing unit tests pass locally with my changes

View File

@@ -41,6 +41,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup NodeJS
uses: actions/setup-node@v4
@@ -60,11 +62,10 @@ jobs:
yarn run lint
- name: Super-linter
uses: super-linter/super-linter/slim@v5
uses: super-linter/super-linter/slim@v6
env:
BASH_SEVERITY: warning
DEFAULT_BRANCH: main
ERROR_ON_MISSING_EXEC_BIT: true
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true
IGNORE_GITIGNORED_FILES: true

.github/workflows/tool-test-sdks.yaml (new file, 34 lines)
View File

@@ -0,0 +1,34 @@
name: Run Unit Test For SDKs
on:
pull_request:
branches:
- main
jobs:
build:
name: unit test for Node.js SDK
runs-on: ubuntu-latest
strategy:
matrix:
node-version: [16, 18, 20]
defaults:
run:
working-directory: sdks/nodejs-client
steps:
- uses: actions/checkout@v4
- name: Use Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v4
with:
node-version: ${{ matrix.node-version }}
cache: ''
cache-dependency-path: 'yarn.lock'
- name: Install Dependencies
run: yarn install
- name: Test
run: yarn test

LICENSE
View File

@@ -1,24 +1,26 @@
# Dify Open Source License
# Open Source License
The Dify project is licensed under the Apache License 2.0, with the following additional conditions:
Dify is licensed under the Apache License 2.0, with the following additional conditions:
1. Dify is permitted to be used for commercialization, such as using Dify as a "backend-as-a-service" for your other applications, or delivering it to enterprises as an application development platform. However, when the following conditions are met, you must contact the producer to obtain a commercial license:
1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify.AI source code to operate a multi-tenant SaaS service that is similar to the Dify.AI service edition.
b. LOGO and copyright information: In the process of using Dify, you may not remove or modify the LOGO or copyright information in the Dify console.
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.
Please contact business@dify.ai by email to inquire about licensing matters.
2. As a contributor, you should agree that your contributed code:
2. As a contributor, you should agree that:
a. The producer can adjust the open-source agreement to be more strict or relaxed.
b. Can be used for commercial purposes, such as Dify's cloud business.
a. The producer can adjust the open-source agreement to be more strict or relaxed as deemed necessary.
b. Your contributed code may be used for commercial purposes, including but not limited to its cloud business operations.
Apart from this, all other rights and restrictions follow the Apache License 2.0. If you need more detailed information, you can refer to the full version of Apache License 2.0.
Apart from the specific conditions mentioned above, all other rights and restrictions follow the Apache License 2.0. Detailed information about the Apache License 2.0 can be found at http://www.apache.org/licenses/LICENSE-2.0.
The interactive design of this product is protected by appearance patent.
© 2023 LangGenius, Inc.
© 2024 LangGenius, Inc.
----------

View File

@@ -21,6 +21,17 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>
<p align="center">
<a href="https://discord.com/events/1082486657678311454/1211724120996188220" target="_blank">
Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
</a>
<ul align="center" style="text-decoration: none; list-style: none;">
<li> US EST: 09:00 (9:00 AM)</li>
<li> CET: 15:00 (3:00 PM)</li>
<li> CST: 22:00 (10:00 PM)</li>
</ul>
</p>
<p align="center">
<a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs

View File

@@ -82,7 +82,7 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
MULTIMODAL_SEND_IMAGE_FORMAT=base64
# Mail configuration, support: resend, smtp
MAIL_TYPE=resend
MAIL_TYPE=
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
@@ -130,3 +130,5 @@ UNSTRUCTURED_API_URL=
SSRF_PROXY_HTTP_URL=
SSRF_PROXY_HTTPS_URL=
BATCH_UPLOAD_LIMIT=10
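With the default MAIL_TYPE now blank, outgoing mail stays unconfigured until a provider is chosen; resend and smtp are the documented options. A minimal reading sketch using only the variables shown above (the actual mail extension in the codebase may consume them differently):

import os

mail_type = os.environ.get('MAIL_TYPE', '')
if mail_type == 'resend':
    resend_api_key = os.environ.get('RESEND_API_KEY', '')
    resend_api_url = os.environ.get('RESEND_API_URL', 'https://api.resend.com')
elif not mail_type:
    print('MAIL_TYPE is empty: mail sending is disabled')

# New in this release: cap on the number of files per batch upload
batch_upload_limit = int(os.environ.get('BATCH_UPLOAD_LIMIT', '10'))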

View File

@@ -38,10 +38,11 @@ from extensions import (
from extensions.ext_database import db
from extensions.ext_login import login_manager
from libs.passport import PassportService
# DO NOT REMOVE BELOW
from services.account_service import AccountService
# DO NOT REMOVE BELOW
from events import event_handlers
from models import account, dataset, model, source, task, tool, tools, web
# DO NOT REMOVE ABOVE
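The comment fence exists because importing events.event_handlers and the model modules registers handlers and tables as an import side effect (see the commit "fix: event handlers not registered globally" above). A generic illustration of the pattern with hypothetical names, not Dify's actual modules:

# Hypothetical module illustrating registration via import side effects.
registry = []

def on_app_created(fn):
    """Decorator that records a handler the moment the module is imported."""
    registry.append(fn)
    return fn

@on_app_created
def announce(app_id):
    print(f'app created: {app_id}')

# Merely importing a module laid out like this populates `registry`,
# which is why the imports above are fenced with DO NOT REMOVE comments:
# an "unused import" cleanup would silently drop the registrations.
print(len(registry))  # 1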

View File

@@ -6,16 +6,16 @@ import click
from flask import current_app
from werkzeug.exceptions import NotFound
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair
from models.account import Tenant
from models.dataset import Dataset
from models.model import Account
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, MessageAnnotation
from models.provider import Provider, ProviderModel
@@ -124,14 +124,124 @@ def reset_encrypt_key_pair():
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
def create_qdrant_indexes():
"""
Migrate other vector database datas to Qdrant.
"""
click.echo(click.style('Start create qdrant indexes.', fg='green'))
create_count = 0
@click.command('vdb-migrate', help='migrate vector db.')
@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
def vdb_migrate(scope: str):
if scope in ['knowledge', 'all']:
migrate_knowledge_vector_database()
if scope in ['annotation', 'all']:
migrate_annotation_vector_database()
def migrate_annotation_vector_database():
"""
Migrate annotation datas to target vector database .
"""
click.echo(click.style('Start migrate annotation data.', fg='green'))
create_count = 0
skipped_count = 0
total_count = 0
page = 1
while True:
try:
# get apps info
apps = db.session.query(App).filter(
App.status == 'normal'
).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for app in apps:
total_count = total_count + 1
click.echo(f'Processing the {total_count} app {app.id}. '
+ f'{create_count} created, {skipped_count} skipped.')
try:
click.echo('Create app annotation index: {}'.format(app.id))
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app.id
).first()
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo('App annotation setting is disabled: {}'.format(app.id))
continue
# get dataset_collection_binding info
dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
).first()
if not dataset_collection_binding:
click.echo('App annotation collection binding is not exist: {}'.format(app.id))
continue
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique='high_quality',
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id
)
documents = []
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question,
metadata={
"annotation_id": annotation.id,
"app_id": app.id,
"doc_id": annotation.id
}
)
documents.append(document)
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
try:
vector.delete()
click.echo(
click.style(f'Successfully delete vector index for app: {app.id}.',
fg='green'))
except Exception as e:
click.echo(
click.style(f'Failed to delete vector index for app {app.id}.',
fg='red'))
raise e
if documents:
try:
click.echo(click.style(
f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
fg='green'))
vector.create(documents)
click.echo(
click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
except Exception as e:
click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
raise e
click.echo(f'Successfully migrated app annotation {app.id}.')
create_count += 1
except Exception as e:
click.echo(
click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(
click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
fg='green'))
def migrate_knowledge_vector_database():
"""
Migrate vector database datas to target vector database .
"""
click.echo(click.style('Start migrate vector db.', fg='green'))
create_count = 0
skipped_count = 0
total_count = 0
config = current_app.config
vector_type = config.get('VECTOR_STORE')
page = 1
while True:
try:
@@ -140,60 +250,128 @@ def create_qdrant_indexes():
except NotFound:
break
model_manager = ModelManager()
page += 1
for dataset in datasets:
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] != 'qdrant':
try:
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except Exception:
continue
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index.create_qdrant_dataset(dataset)
index_struct = {
"type": 'qdrant',
"vector_store": {
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
create_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
total_count = total_count + 1
click.echo(f'Processing the {total_count} dataset {dataset.id}. '
+ f'{create_count} created, ${skipped_count} skipped.')
try:
click.echo('Create dataset vdb index: {}'.format(dataset.id))
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] == vector_type:
skipped_count = skipped_count + 1
continue
collection_name = ''
if vector_type == "weaviate":
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'weaviate',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "qdrant":
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'qdrant',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
elif vector_type == "milvus":
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'milvus',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
vector = Vector(dataset)
click.echo(f"Start to migrate dataset {dataset.id}.")
try:
vector.delete()
click.echo(
click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
fg='green'))
except Exception as e:
click.echo(
click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
fg='red'))
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
segments_count = 0
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
segments_count = segments_count + 1
if documents:
try:
click.echo(click.style(
f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
fg='green'))
vector.create(documents)
click.echo(
click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
except Exception as e:
click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
raise e
db.session.add(dataset)
db.session.commit()
click.echo(f'Successfully migrated dataset {dataset.id}.')
create_count += 1
except Exception as e:
db.session.rollback()
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(
click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(create_qdrant_indexes)
app.cli.add_command(vdb_migrate)
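A usage sketch for the new command, assuming the standard Flask CLI entry point of the API service; the app import path is an assumption, while the command name and --scope option come from the diff above. Running it executes the migration against the configured database.

# Hypothetical invocation through Flask's CLI test runner
# (equivalent to running `flask vdb-migrate --scope knowledge` in the api directory).
from app import app  # assumption: the Flask application object exported by the API package

runner = app.test_cli_runner()
result = runner.invoke(args=['vdb-migrate', '--scope', 'knowledge'])  # or 'annotation' / 'all'
print(result.output)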

View File

@@ -38,7 +38,9 @@ DEFAULTS = {
'LOG_LEVEL': 'INFO',
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
@@ -56,6 +58,8 @@ DEFAULTS = {
'BILLING_ENABLED': 'False',
'CAN_REPLACE_LOGO': 'False',
'ETL_TYPE': 'dify',
'KEYWORD_STORE': 'jieba',
'BATCH_UPLOAD_LIMIT': 20
}
@@ -86,7 +90,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.5.5"
self.CURRENT_VERSION = "0.5.8"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -182,7 +186,7 @@ class Config:
# Currently, only support: qdrant, milvus, zilliz, weaviate
# ------------------------
self.VECTOR_STORE = get_env('VECTOR_STORE')
self.KEYWORD_STORE = get_env('KEYWORD_STORE')
# qdrant settings
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
@@ -259,8 +263,10 @@ class Config:
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
self.HOSTED_OPENAI_TRIAL_MODELS = get_env('HOSTED_OPENAI_TRIAL_MODELS')
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_MODELS = get_env('HOSTED_OPENAI_PAID_MODELS')
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
@@ -285,6 +291,8 @@ class Config:
self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
class CloudEditionConfig(Config):
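HOSTED_OPENAI_TRIAL_MODELS and HOSTED_OPENAI_PAID_MODELS arrive as comma-separated strings; a small sketch of splitting such a value for membership checks (parse_model_list is a hypothetical helper, not part of the diff):

import os

def parse_model_list(name: str) -> list[str]:
    """Split a comma-separated env var into a clean list of model names."""
    raw = os.environ.get(name, '')
    return [item.strip() for item in raw.split(',') if item.strip()]

paid_models = parse_model_list('HOSTED_OPENAI_PAID_MODELS')
print('gpt-4' in paid_models)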

View File

@@ -1,9 +1,8 @@
import json
from models.model import AppModelConfig
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT']
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA']
language_timezone_mapping = {
'en-US': 'America/New_York',
@@ -16,8 +15,10 @@ language_timezone_mapping = {
'ko-KR': 'Asia/Seoul',
'ru-RU': 'Europe/Moscow',
'it-IT': 'Europe/Rome',
'uk-UA': 'Europe/Kyiv',
}
def supported_language(lang):
if lang in languages:
return lang
@@ -26,6 +27,7 @@ def supported_language(lang):
.format(lang=lang))
raise ValueError(error)
user_input_form_template = {
"en-US": [
{
@@ -67,6 +69,16 @@ user_input_form_template = {
}
}
],
"ua-UK": [
{
"paragraph": {
"label": "Запит",
"variable": "default_input",
"required": False,
"default": ""
}
}
],
}
demo_model_templates = {
@@ -145,7 +157,7 @@ demo_model_templates = {
'Italian',
]
}
},{
}, {
"paragraph": {
"label": "Query",
"variable": "query",
@@ -272,7 +284,7 @@ demo_model_templates = {
"意大利语",
]
}
},{
}, {
"paragraph": {
"label": "文本内容",
"variable": "query",
@@ -323,5 +335,130 @@ demo_model_templates = {
)
}
],
'uk-UA': [{
"name": "Помічник перекладу",
"icon": "",
"icon_background": "",
"description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.",
"mode": "completion",
"model_config": AppModelConfig(
provider="openai",
model_id="gpt-3.5-turbo-instruct",
configs={
"prompt_template": "Будь ласка, перекладіть наступний текст на {{target_language}}:\n",
"prompt_variables": [
{
"key": "target_language",
"name": "Цільова мова",
"description": "Мова, на яку ви хочете перекласти.",
"type": "select",
"default": "Ukrainian",
"options": [
"Chinese",
"English",
"Japanese",
"French",
"Russian",
"German",
"Spanish",
"Korean",
"Italian",
],
},
],
"completion_params": {
"max_token": 1000,
"temperature": 0,
"top_p": 0,
"presence_penalty": 0.1,
"frequency_penalty": 0.1,
},
},
opening_statement="",
suggested_questions=None,
pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
"top_p": 0,
"presence_penalty": 0.1,
"frequency_penalty": 0.1,
},
}),
user_input_form=json.dumps([
{
"select": {
"label": "Цільова мова",
"variable": "target_language",
"description": "Мова, на яку ви хочете перекласти.",
"default": "Chinese",
"required": True,
'options': [
'Chinese',
'English',
'Japanese',
'French',
'Russian',
'German',
'Spanish',
'Korean',
'Italian',
]
}
}, {
"paragraph": {
"label": "Запит",
"variable": "query",
"required": True,
"default": ""
}
}
])
)
},
{
"name": "AI інтерв’юер фронтенду",
"icon": "",
"icon_background": "",
"description": "Симульований інтерв’юер фронтенду, який перевіряє рівень кваліфікації у розробці фронтенду через опитування.",
"mode": "chat",
"model_config": AppModelConfig(
provider="openai",
model_id="gpt-3.5-turbo",
configs={
"introduction": "Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
"prompt_template": "Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
"prompt_variables": [],
"completion_params": {
"max_token": 300,
"temperature": 0.8,
"top_p": 0.9,
"presence_penalty": 0.1,
"frequency_penalty": 0.1,
},
},
opening_statement="Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
suggested_questions=None,
pre_prompt="Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,
"top_p": 0.9,
"presence_penalty": 0.1,
"frequency_penalty": 0.1,
},
}),
user_input_form=None
),
}
],
}
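A short usage sketch for the language helpers above; the import path constants.languages is assumed from the surrounding diff, and the values follow the new uk-UA entries:

from constants.languages import language_timezone_mapping, supported_language  # module path assumed

lang = supported_language('uk-UA')        # returns the code because it is in `languages`
print(language_timezone_mapping[lang])    # 'Europe/Kyiv'
# supported_language('xx-XX') would raise ValueError for an unsupported code.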

View File

@@ -13,30 +13,14 @@ model_templates = {
'status': 'normal'
},
'model_config': {
'provider': 'openai',
'model_id': 'gpt-3.5-turbo-instruct',
'configs': {
'prompt_template': '',
'prompt_variables': [],
'completion_params': {
'max_token': 512,
'temperature': 1,
'top_p': 1,
'presence_penalty': 0,
'frequency_penalty': 0,
}
},
'provider': '',
'model_id': '',
'configs': {},
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
"top_p": 1,
"presence_penalty": 0,
"frequency_penalty": 0
}
"completion_params": {}
}),
'user_input_form': json.dumps([
{
@@ -64,30 +48,14 @@ model_templates = {
'status': 'normal'
},
'model_config': {
'provider': 'openai',
'model_id': 'gpt-3.5-turbo',
'configs': {
'prompt_template': '',
'prompt_variables': [],
'completion_params': {
'max_token': 512,
'temperature': 1,
'top_p': 1,
'presence_penalty': 0,
'frequency_penalty': 0,
}
},
'provider': '',
'model_id': '',
'configs': {},
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
"top_p": 1,
"presence_penalty": 0,
"frequency_penalty": 0
}
"completion_params": {}
})
}
},

View File

@@ -124,19 +124,13 @@ class AppListApi(Resource):
available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
if provider_model not in available_models_names:
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
if not model_instance:
if not default_model_entity:
raise ProviderNotInitializeError(
"No Default System Reasoning Model available. Please configure "
"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = model_instance.provider
model_config_dict["model"]["name"] = model_instance.model
model_config_dict["model"]["provider"] = default_model_entity.provider.provider
model_config_dict["model"]["name"] = default_model_entity.model
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,

View File

@@ -1,7 +1,7 @@
import logging
from flask import request
from flask_restful import Resource
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError
import services
@@ -45,7 +45,8 @@ class ChatMessageAudioApi(Resource):
try:
response = AudioService.transcript_asr(
tenant_id=app_model.tenant_id,
file=file
file=file,
end_user=None,
)
return response
@@ -71,7 +72,7 @@ class ChatMessageAudioApi(Resource):
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
logging.exception(f"internal server error, {str(e)}.")
raise InternalServerError()
@@ -82,10 +83,12 @@ class ChatMessageTextApi(Resource):
def post(self, app_id):
app_id = str(app_id)
app_model = _get_app(app_id, None)
try:
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=request.form['text'],
voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
)
@@ -112,9 +115,50 @@ class ChatMessageTextApi(Resource):
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
logging.exception(f"internal server error, {str(e)}.")
raise InternalServerError()
class TextModesApi(Resource):
def get(self, app_id: str):
app_model = _get_app(str(app_id))
try:
parser = reqparse.RequestParser()
parser.add_argument('language', type=str, required=True, location='args')
args = parser.parse_args()
response = AudioService.transcript_tts_voices(
tenant_id=app_model.tenant_id,
language=args['language'],
)
return response
except services.errors.audio.ProviderNotSupportTextToSpeechLanageServiceError:
raise AppUnavailableError("Text to audio voices language parameter loss.")
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception(f"internal server error, {str(e)}.")
raise InternalServerError()
api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
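A minimal sketch of calling the new voices endpoint; the base URL, app id, and auth header are placeholders, and requiring a console session token is an assumption:

import requests

base_url = 'http://localhost:5001/console/api'   # placeholder console API base
app_id = '<app-uuid>'                            # placeholder application id

resp = requests.get(
    f'{base_url}/apps/{app_id}/text-to-audio/voices',
    params={'language': 'zh-Hans'},              # required query argument per the parser above
    headers={'Authorization': 'Bearer <console-access-token>'},  # placeholder credential
)
print(resp.json())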

View File

@@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.login import login_required
@@ -173,14 +174,15 @@ class DataSourceNotionApi(Resource):
if not data_source_binding:
raise NotFound('Data source binding not found.')
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
extractor = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type
notion_page_type=page_type,
notion_access_token=data_source_binding.access_token,
tenant_id=current_user.current_tenant_id
)
text_docs = loader.load()
text_docs = extractor.extract()
return {
'content': "\n".join([doc.page_content for doc in text_docs])
}, 200
@@ -192,11 +194,31 @@ class DataSourceNotionApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
notion_info_list = args['notion_info_list']
extract_settings = []
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
for page in notion_info['pages']:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page['page_id'],
"notion_page_type": page['type'],
"tenant_id": current_user.current_tenant_id
},
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
args['process_rule'], args['doc_form'],
args['doc_language'])
return response, 200

View File

@@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@@ -178,9 +179,9 @@ class DatasetApi(Resource):
location='json', store_missing=False,
type=_validate_description_length)
parser.add_argument('indexing_technique', type=str, location='json',
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help='Invalid indexing technique.')
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
'only_me', 'all_team_members'), help='Invalid permission.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
@@ -258,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('indexing_technique', type=str, required=True,
parser.add_argument('indexing_technique', type=str, required=True,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
@@ -268,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource):
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
extract_settings = []
if args['info_list']['data_source_type'] == 'upload_file':
file_ids = args['info_list']['file_info_list']['file_ids']
file_details = db.session.query(UploadFile).filter(
@@ -278,37 +280,45 @@ class DatasetIndexingEstimateApi(Resource):
if file_details is None:
raise NotFound("File not found.")
indexing_runner = IndexingRunner()
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'],
args['indexing_technique'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=file_detail,
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
elif args['info_list']['data_source_type'] == 'notion_import':
indexing_runner = IndexingRunner()
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
args['info_list']['notion_info_list'],
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'],
args['indexing_technique'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
notion_info_list = args['info_list']['notion_info_list']
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
for page in notion_info['pages']:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page['page_id'],
"notion_page_type": page['type'],
"tenant_id": current_user.current_tenant_id
},
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
else:
raise ValueError('Data source type not support')
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'],
args['indexing_technique'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
return response, 200
@@ -508,4 +518,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
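The indexing-estimate paths now build one ExtractSetting per data source before calling indexing_runner.indexing_estimate. A compact sketch of the two constructor shapes used above; the keyword arguments mirror the diff, while the variable values are placeholders:

from core.rag.extractor.entity.extract_setting import ExtractSetting  # import path as used in the diff

# Shape 1: a locally uploaded file.
upload_file_record = None  # placeholder; a real call passes an UploadFile row fetched from the DB
file_setting = ExtractSetting(
    datasource_type='upload_file',
    upload_file=upload_file_record,
    document_model='text_model',
)

# Shape 2: a page imported from Notion.
notion_setting = ExtractSetting(
    datasource_type='notion_import',
    notion_info={
        'notion_workspace_id': '<notion-workspace-id>',  # placeholder
        'notion_obj_id': '<notion-page-id>',             # placeholder
        'notion_page_type': 'page',
        'tenant_id': '<tenant-id>',                      # placeholder
    },
    document_model='text_model',
)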

View File

@@ -32,6 +32,7 @@ from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.document_fields import (
@@ -95,7 +96,7 @@ class GetProcessRuleApi(Resource):
req_data = request.args
document_id = req_data.get('document_id')
# get default rules
mode = DocumentService.DEFAULT_RULES['mode']
rules = DocumentService.DEFAULT_RULES['rules']
@@ -362,12 +363,18 @@ class DocumentIndexingEstimateApi(DocumentResource):
if not file:
raise NotFound('File not found.')
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=file,
document_model=document.doc_form
)
indexing_runner = IndexingRunner()
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
data_process_rule_dict, None,
'English', dataset_id)
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
data_process_rule_dict, document.doc_form,
'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
@@ -402,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
data_process_rule = documents[0].dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict()
info_list = []
extract_settings = []
for document in documents:
if document.indexing_status in ['completed', 'error']:
raise DocumentAlreadyFinishedError()
@@ -424,42 +432,49 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
}
info_list.append(notion_info)
if dataset.data_source_type == 'upload_file':
file_details = db.session.query(UploadFile).filter(
UploadFile.tenant_id == current_user.current_tenant_id,
UploadFile.id.in_(info_list)
).all()
if document.data_source_type == 'upload_file':
file_id = data_source_info['upload_file_id']
file_detail = db.session.query(UploadFile).filter(
UploadFile.tenant_id == current_user.current_tenant_id,
UploadFile.id == file_id
).first()
if file_details is None:
raise NotFound("File not found.")
if file_detail is None:
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=file_detail,
document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == 'notion_import':
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": data_source_info['notion_workspace_id'],
"notion_obj_id": data_source_info['notion_page_id'],
"notion_page_type": data_source_info['type'],
"tenant_id": current_user.current_tenant_id
},
document_model=document.doc_form
)
extract_settings.append(extract_setting)
else:
raise ValueError('Data source type not supported')
indexing_runner = IndexingRunner()
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
data_process_rule_dict, None,
'English', dataset_id)
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
data_process_rule_dict, document.doc_form,
'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
elif dataset.data_source_type == 'notion_import':
indexing_runner = IndexingRunner()
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
info_list,
data_process_rule_dict,
None, 'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
else:
raise ValueError('Data source type not supported')
return response
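For context, the change above replaces the source-specific estimate helpers (file_indexing_estimate / notion_indexing_estimate) with a single indexing_estimate call driven by a list of ExtractSetting objects, one per document. A minimal, self-contained sketch of that shape (illustrative names and simplified fields, not the project's actual classes):
from dataclasses import dataclass
from typing import Optional

@dataclass
class ExtractSetting:
    datasource_type: str            # "upload_file" or "notion_import"
    document_model: str
    upload_file: Optional[object] = None
    notion_info: Optional[dict] = None

def indexing_estimate(tenant_id: str, extract_settings: list[ExtractSetting],
                      process_rule: dict, doc_form: str,
                      doc_language: str, dataset_id: str) -> dict:
    # One code path handles every data source; the per-source branching
    # happens when the ExtractSetting is built, not when it is consumed.
    return {"total_segments": 0,
            "sources": [s.datasource_type for s in extract_settings]}

settings = [
    ExtractSetting(datasource_type="upload_file", document_model="text_model",
                   upload_file=object()),
    ExtractSetting(datasource_type="notion_import", document_model="text_model",
                   notion_info={"notion_workspace_id": "ws", "notion_obj_id": "page",
                                "notion_page_type": "page", "tenant_id": "tenant"}),
]
print(indexing_estimate("tenant", settings, {}, "text_model", "English", "dataset"))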

View File

@@ -11,7 +11,7 @@ from controllers.console.datasets.error import (
UnsupportedFileTypeError,
)
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from services.file_service import ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS, FileService
@@ -39,6 +39,7 @@ class FileApi(Resource):
@login_required
@account_initialization_required
@marshal_with(file_fields)
@cloud_edition_billing_resource_check(resource='documents')
def post(self):
# get file from request

View File

@@ -85,6 +85,7 @@ class ChatTextApi(InstalledAppResource):
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=request.form['text'],
voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
)
return {'data': response.data.decode('latin1')}

View File

@@ -259,6 +259,7 @@ class ToolApiProviderPreviousTestApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
@@ -268,6 +269,7 @@ class ToolApiProviderPreviousTestApi(Resource):
return ToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
args['provider_name'] if args['provider_name'] else '',
args['tool_name'],
args['credentials'],
args['parameters'],

View File

@@ -56,6 +56,7 @@ def cloud_edition_billing_resource_check(resource: str,
members = features.members
apps = features.apps
vector_space = features.vector_space
documents_upload_quota = features.documents_upload_quota
annotation_quota_limit = features.annotation_quota_limit
if resource == 'members' and 0 < members.limit <= members.size:
@@ -64,6 +65,13 @@ def cloud_edition_billing_resource_check(resource: str,
abort(403, error_msg)
elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
abort(403, error_msg)
elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The file upload API is used in multiple places, so we need to check whether the request comes from datasets
source = request.args.get('source')
if source == 'datasets':
abort(403, error_msg)
else:
return view(*args, **kwargs)
elif resource == 'workspace_custom' and not features.can_replace_logo:
abort(403, error_msg)
elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
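The documents quota branch above only rejects uploads that explicitly come from the datasets page; other callers of the shared upload endpoint pass through. A self-contained sketch of that rule (an assumed simplification, not the project's actual decorator):
from dataclasses import dataclass
from typing import Optional

@dataclass
class Quota:
    limit: int   # 0 means unlimited
    size: int    # current usage

def should_block_document_upload(quota: Quota, source: Optional[str]) -> bool:
    """Return True when the upload should be rejected with HTTP 403."""
    over_quota = 0 < quota.limit <= quota.size
    return over_quota and source == 'datasets'

assert should_block_document_upload(Quota(limit=50, size=50), 'datasets') is True
assert should_block_document_upload(Quota(limit=50, size=50), None) is False       # other callers pass
assert should_block_document_upload(Quota(limit=0, size=999), 'datasets') is False  # limit 0 = unlimited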

View File

@@ -1,27 +0,0 @@
from extensions.ext_database import db
from models.model import EndUser
def create_or_update_end_user_for_user_id(app_model, user_id):
"""
Create or update session terminal based on user ID.
"""
end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()
return end_user

View File

@@ -1,16 +1,16 @@
import json
from flask import current_app
from flask_restful import fields, marshal_with
from flask_restful import fields, marshal_with, Resource
from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from models.model import App, AppModelConfig
from models.tools import ApiToolProvider
class AppParameterApi(AppApiResource):
class AppParameterApi(Resource):
"""Resource for app variables."""
variable_fields = {
@@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource):
'system_parameters': fields.Nested(system_parameters_fields)
}
@validate_app_token
@marshal_with(parameters_fields)
def get(self, app_model: App, end_user):
def get(self, app_model: App):
"""Retrieve app parameters."""
app_model_config = app_model.app_model_config
@@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource):
}
}
class AppMetaApi(AppApiResource):
def get(self, app_model: App, end_user):
class AppMetaApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""Get app meta"""
app_model_config: AppModelConfig = app_model.app_model_config

View File

@@ -1,7 +1,7 @@
import logging
from flask import request
from flask_restful import reqparse
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError
import services
@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppModelConfig
from models.model import App, AppModelConfig, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -30,8 +30,9 @@ from services.errors.audio import (
)
class AudioApi(AppApiResource):
def post(self, app_model: App, end_user):
class AudioApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
def post(self, app_model: App, end_user: EndUser):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
raise InternalServerError()
class TextApi(AppApiResource):
def post(self, app_model: App, end_user):
class TextApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
args = parser.parse_args()
@@ -85,7 +86,8 @@ class TextApi(AppApiResource):
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=args['text'],
end_user=args['user'],
end_user=end_user,
voice=args['voice'] if args['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=args['streaming']
)

View File

@@ -4,12 +4,11 @@ from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
from flask_restful import reqparse
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
@@ -19,17 +18,19 @@ from controllers.service_api.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App, EndUser
from services.completion_service import CompletionService
class CompletionApi(AppApiResource):
def post(self, app_model, end_user):
class CompletionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'completion':
raise AppUnavailableError()
@@ -38,16 +39,12 @@ class CompletionApi(AppApiResource):
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
args['auto_generate_name'] = False
try:
@@ -82,29 +79,20 @@ class CompletionApi(AppApiResource):
raise InternalServerError()
class CompletionStopApi(AppApiResource):
def post(self, app_model, end_user, task_id):
class CompletionStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'completion':
raise AppUnavailableError()
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
class ChatApi(AppApiResource):
def post(self, app_model, end_user):
class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
@@ -114,7 +102,6 @@ class ChatApi(AppApiResource):
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
@@ -122,9 +109,6 @@ class ChatApi(AppApiResource):
streaming = args['response_mode'] == 'streaming'
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
response = CompletionService.completion(
app_model=app_model,
@@ -157,22 +141,12 @@ class ChatApi(AppApiResource):
raise InternalServerError()
class ChatStopApi(AppApiResource):
def post(self, app_model, end_user, task_id):
class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'chat':
raise NotChatAppError()
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200

View File

@@ -1,52 +1,44 @@
from flask import request
from flask_restful import marshal_with, reqparse
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import App, EndUser
from services.conversation_service import ConversationService
class ConversationApi(AppApiResource):
class ConversationApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
class ConversationDetailApi(AppApiResource):
class ConversationDetailApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def delete(self, app_model, end_user, c_id):
def delete(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
conversation_id = str(c_id)
user = request.get_json().get('user')
if end_user is None and user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
try:
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
return {"result": "success"}, 204
class ConversationRenameApi(AppApiResource):
class ConversationRenameApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
def post(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
return ConversationService.rename(
app_model,

View File

@@ -1,30 +1,27 @@
from flask import request
from flask_restful import marshal_with
from flask_restful import Resource, marshal_with
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import (
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.file_fields import file_fields
from models.model import App, EndUser
from services.file_service import FileService
class FileApi(AppApiResource):
class FileApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
@marshal_with(file_fields)
def post(self, app_model, end_user):
def post(self, app_model: App, end_user: EndUser):
file = request.files['file']
user_args = request.form.get('user')
if end_user is None and user_args is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user_args)
# check file
if 'file' not in request.files:

View File

@@ -1,20 +1,18 @@
from flask_restful import fields, marshal_with, reqparse
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from extensions.ext_database import db
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from models.model import EndUser, Message
from models.model import App, EndUser
from services.message_service import MessageService
class MessageListApi(AppApiResource):
class MessageListApi(Resource):
feedback_fields = {
'rating': fields.String
}
@@ -70,8 +68,9 @@ class MessageListApi(AppApiResource):
'data': fields.List(fields.Nested(message_fields))
}
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
@@ -79,12 +78,8 @@ class MessageListApi(AppApiResource):
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
return MessageService.pagination_by_first_id(app_model, end_user,
args['conversation_id'], args['first_id'], args['limit'])
@@ -94,18 +89,15 @@ class MessageListApi(AppApiResource):
raise NotFound("First Message Not Exists.")
class MessageFeedbackApi(AppApiResource):
def post(self, app_model, end_user, message_id):
class MessageFeedbackApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
parser.add_argument('user', type=str, location='json')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
except services.errors.message.MessageNotExistsError:
@@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
return {'result': 'success'}
class MessageSuggestedApi(AppApiResource):
def get(self, app_model, end_user, message_id):
class MessageSuggestedApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
if app_model.mode != 'chat':
raise NotChatAppError()
try:
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
).first()
if end_user is None and message.from_end_user_id is not None:
user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.id == message.from_end_user_id,
EndUser.type == 'service_api'
).first()
else:
user = end_user
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=user,
user=end_user,
message_id=message_id,
check_enabled=False
)

View File

@@ -28,6 +28,7 @@ class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents."""
@cloud_edition_billing_resource_check('vector_space', 'dataset')
@cloud_edition_billing_resource_check('documents', 'dataset')
def post(self, tenant_id, dataset_id):
"""Create document by text."""
parser = reqparse.RequestParser()
@@ -153,6 +154,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
class DocumentAddByFileApi(DatasetApiResource):
"""Resource for documents."""
@cloud_edition_billing_resource_check('vector_space', 'dataset')
@cloud_edition_billing_resource_check('documents', 'dataset')
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
args = {}

View File

@@ -1,22 +1,40 @@
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from functools import wraps
from typing import Optional
from flask import current_app, request
from flask_login import user_logged_in
from flask_restful import Resource
from pydantic import BaseModel
from werkzeug.exceptions import NotFound, Unauthorized
from extensions.ext_database import db
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin
from models.model import ApiToken, App
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
def validate_app_token(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
class WhereisUserArg(Enum):
"""
Enum for whereis_user_arg.
"""
QUERY = 'query'
JSON = 'json'
FORM = 'form'
class FetchUserArg(BaseModel):
fetch_from: WhereisUserArg
required: bool = False
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token('app')
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
@@ -29,16 +47,35 @@ def validate_app_token(view=None):
if not app_model.enable_api:
raise NotFound()
return view(app_model, None, *args, **kwargs)
return decorated
kwargs['app_model'] = app_model
if view:
if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
user_id = request.get_json().get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get('user')
else:
# use default-user
user_id = None
if not user_id and fetch_user_arg.required:
raise ValueError("Arg user must be provided.")
if user_id:
user_id = str(user_id)
kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
# if view is None, it means that the decorator is used without parentheses
# use the decorator as a function for method_decorators
return decorator
def cloud_edition_billing_resource_check(resource: str,
api_token_type: str,
@@ -52,6 +89,7 @@ def cloud_edition_billing_resource_check(resource: str,
members = features.members
apps = features.apps
vector_space = features.vector_space
documents_upload_quota = features.documents_upload_quota
if resource == 'members' and 0 < members.limit <= members.size:
raise Unauthorized(error_msg)
@@ -59,6 +97,8 @@ def cloud_edition_billing_resource_check(resource: str,
raise Unauthorized(error_msg)
elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
raise Unauthorized(error_msg)
elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
raise Unauthorized(error_msg)
else:
return view(*args, **kwargs)
@@ -128,8 +168,33 @@ def validate_and_get_api_token(scope=None):
return api_token
class AppApiResource(Resource):
method_decorators = [validate_app_token]
def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
"""
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = 'DEFAULT-USER'
end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True if user_id == 'DEFAULT-USER' else False,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()
return end_user
class DatasetApiResource(Resource):
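The refactor above turns validate_app_token into a decorator that works both bare and with a FetchUserArg argument, injecting app_model and end_user into the view as keyword arguments. A generic, self-contained sketch of that optional-argument decorator pattern (illustrative names only, not the actual implementation):
from functools import wraps
from typing import Callable, Optional

def validate(view: Optional[Callable] = None, *, inject: Optional[str] = None):
    def decorator(view_func):
        @wraps(view_func)
        def decorated(*args, **kwargs):
            # stand-in for resolving app_model / end_user before calling the view
            kwargs['injected'] = inject or 'default-user'
            return view_func(*args, **kwargs)
        return decorated
    # bare usage (@validate): 'view' is the function itself, decorate it directly;
    # parameterized usage (@validate(inject=...)): return the decorator to apply next
    return decorator(view) if view else decorator

@validate
def plain(**kwargs):
    return kwargs['injected']

@validate(inject='end-user-42')
def parameterized(**kwargs):
    return kwargs['injected']

assert plain() == 'default-user'
assert parameterized() == 'end-user-42'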

View File

@@ -68,17 +68,23 @@ class AudioApi(WebApiResource):
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
logging.exception(f"internal server error: {str(e)}")
raise InternalServerError()
class TextApi(WebApiResource):
def post(self, app_model: App, end_user):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.text_to_speech_dict['enabled']:
raise AppUnavailableError()
try:
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=request.form['text'],
end_user=end_user.external_user_id,
voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
)
@@ -105,7 +111,7 @@ class TextApi(WebApiResource):
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
logging.exception(f"internal server error: {str(e)}")
raise InternalServerError()

View File

@@ -1,49 +0,0 @@
from typing import cast
from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class CalcTokenMixin:
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
"""
Get the rest tokens available for the model after excluding message tokens and completion max tokens
:param model_config:
:param messages:
:return:
"""
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return 0
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
return rest_tokens
class ExceededLLMTokensLimitError(Exception):
pass

View File

@@ -1,361 +0,0 @@
from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import (
AgentAction,
AgentFinish,
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
get_buffer_string,
)
from langchain.tools import BaseTool
from pydantic import root_validator
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_config: ModelConfigEntity = None
model_config: ModelConfigEntity
agent_llm_callback: Optional[AgentLLMCallback] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> BaseSingleActionAgent:
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_config=model_config,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
**kwargs,
)
def should_use_agent(self, query: str):
"""
return should use agent
:param query:
:return:
"""
original_max_tokens = 0
for parameter_rule in self.model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
or self.model_config.parameters.get(parameter_rule.use_template)) or 0
self.model_config.parameters['max_tokens'] = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
prompt_messages = lc_messages_to_prompt_messages(messages)
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
except Exception as e:
raise e
self.model_config.parameters['max_tokens'] = original_max_tokens
return True if result.message.tool_calls else False
def plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
# summarize messages if rest_tokens < 0
try:
prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
ai_message = AIMessage(
content=result.message.content or "",
additional_kwargs={
'function_call': {
'id': result.message.tool_calls[0].id,
**result.message.tool_calls[0].function.dict()
} if result.message.tool_calls else None
}
)
agent_decision = _parse_ai_message(ai_message)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
@classmethod
def get_system_message(cls):
return SystemMessage(content="You are a helpful AI assistant.\n"
"The current date or current time you know is wrong.\n"
"Respond directly if appropriate.")
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
try:
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(
self.model_config,
messages,
**kwargs
)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def predict_new_summary(
self, messages: list[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model_config.provider == 'azure_openai':
model = model_config.model
model = model.replace("gpt-35", "gpt-3.5")
else:
model = model_config.credentials.get("base_model_name")
tiktoken_ = _import_tiktoken()
try:
encoding = tiktoken_.encoding_for_model(model)
except KeyError:
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens

View File

@@ -1,306 +0,0 @@
import re
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import (
AgentAction,
AgentFinish,
AIMessage,
BaseMessage,
HumanMessage,
OutputParserException,
get_buffer_string,
)
from langchain.tools import BaseTool
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_config: ModelConfigEntity = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
return should use agent
Using the ReACT mode to determine whether an agent is needed is costly,
so it's better to just use an Agent for reasoning, which is cheaper.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
messages = []
if prompts:
messages = prompts[0].to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
raise e
try:
agent_decision = self.output_parser.parse(full_output)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_model_config:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
self.moving_summary_index = len(intermediate_steps)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
if self.moving_summary_buffer and 'chat_history' in kwargs:
kwargs["chat_history"].pop()
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
if 'chat_history' in kwargs:
kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
def predict_new_summary(
self, messages: list[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
formatted_tools = "\n".join(tool_strings)
tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def create_completion_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:
tools: List of tools the agent will have access to, used to format the
prompt.
prefix: String to put before the list of tools.
input_variables: List of input variables the final prompt will expect.
Returns:
A PromptTemplate with the template assembled from the pieces here.
"""
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
agent_scratchpad += action.log
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
if not isinstance(agent_scratchpad, str):
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_config.mode == "chat":
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
else:
return agent_scratchpad
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_config.mode == "chat":
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
else:
prompt = cls.create_completion_prompt(
tools,
prefix=prefix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
model_config=model_config,
prompt=prompt,
callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
**kwargs,
)

View File

@@ -84,7 +84,7 @@ class AppRunner:
return rest_tokens
def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
def recalc_llm_max_tokens(self, model_config: ModelConfigEntity,
prompt_messages: list[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_type_instance = model_config.provider_model_bundle.model_type_instance

View File

@@ -1,4 +1,3 @@
import json
import logging
from typing import cast
@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
message_chain = self._init_message_chain(
message=message,
query=query
)
# init model instance
model_instance = ModelInstance(
@@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner):
'pool': db_variables.variables
})
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
"""
Init MessageChain
:param message: message
:param query: query
:return:
"""
message_chain = MessageChain(
message_id=message.id,
type="AgentExecutor",
input=json.dumps({
"input": query
})
)
db.session.add(message_chain)
db.session.commit()
return message_chain
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
"""
Save MessageChain
:param message_chain: message chain
:param output_text: output text
:return:
"""
message_chain.output = json.dumps({
"output": output_text
})
db.session.commit()
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
message: Message) -> LLMUsage:
"""

View File

@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
@@ -181,7 +181,7 @@ class BasicApplicationRunner(AppRunner):
return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recale_llm_max_tokens(
self.recalc_llm_max_tokens(
model_config=app_orchestration_config.model_config,
prompt_messages=prompt_messages
)

View File

@@ -175,7 +175,7 @@ class GenerateTaskPipeline:
'id': self._message.id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'answer': event.llm_result.message.content,
'answer': self._task_state.llm_result.message.content,
'metadata': {},
'created_at': int(self._message.created_at.timestamp())
}

View File

@@ -28,6 +28,7 @@ from core.entities.application_entities import (
ModelConfigEntity,
PromptTemplateEntity,
SensitiveWordAvoidanceEntity,
TextToSpeechEntity,
)
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
@@ -572,7 +573,11 @@ class ApplicationManager:
text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
if text_to_speech_dict:
if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
properties['text_to_speech'] = True
properties['text_to_speech'] = TextToSpeechEntity(
enabled=text_to_speech_dict.get('enabled'),
voice=text_to_speech_dict.get('voice'),
language=text_to_speech_dict.get('language'),
)
# sensitive word avoidance
sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
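In the change above, the text_to_speech property goes from a bare boolean to a structured entity carrying voice and language. A hedged sketch of the presumed shape (field names taken from the diff; the actual entity lives in core.entities.application_entities):
from typing import Optional
from pydantic import BaseModel

class TextToSpeechEntity(BaseModel):
    enabled: bool = False
    voice: Optional[str] = None      # assumed optional: falls back to a model default elsewhere
    language: Optional[str] = None

tts_dict = {'enabled': True, 'voice': 'alloy', 'language': 'en-US'}
if tts_dict and tts_dict.get('enabled'):
    tts = TextToSpeechEntity(
        enabled=tts_dict.get('enabled'),
        voice=tts_dict.get('voice'),
        language=tts_dict.get('language'),
    )
    print(tts.voice)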

View File

@@ -1,8 +1,7 @@
from langchain.schema import Document
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import InvokeFrom
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import DatasetQuery, DocumentSegment
from models.model import DatasetRetrieverResource

View File

@@ -1,107 +0,0 @@
import tempfile
from pathlib import Path
from typing import Optional, Union
import requests
from flask import current_app
from langchain.document_loaders import Docx2txtLoader, TextLoader
from langchain.schema import Document
from core.data_loader.loader.csv_loader import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
from extensions.ext_storage import storage
from models.model import UploadFile
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
response = requests.get(url, headers={
"User-Agent": USER_AGENT
})
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(url).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
with open(file_path, 'wb') as file:
file.write(response.content)
return cls.load_from_file(file_path, return_text)
@classmethod
def load_from_file(cls, file_path: str, return_text: bool = False,
upload_file: Optional[UploadFile] = None,
is_automatic: bool = False) -> Union[list[Document], str]:
input_file = Path(file_path)
delimiter = '\n'
file_extension = input_file.suffix.lower()
etl_type = current_app.config['ETL_TYPE']
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
if etl_type == 'Unstructured':
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
else MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension in ['.docx', '.doc']:
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
elif file_extension == '.eml':
loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
elif file_extension == '.ppt':
loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
elif file_extension == '.pptx':
loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
elif file_extension == '.xml':
loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
else:
# txt
loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
else TextLoader(file_path, autodetect_encoding=True)
else:
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension in ['.docx', '.doc']:
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()

View File
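The FileExtractor removed above chooses a loader from the file extension and the ETL_TYPE setting, with a plain-text fallback. A stripped-down sketch of that dispatch pattern; the load_* callables are hypothetical stand-ins for the loader classes, not the real ones:

from pathlib import Path
from typing import Callable

def load_excel(path: str) -> str: return f"excel:{path}"
def load_pdf(path: str) -> str: return f"pdf:{path}"
def load_text(path: str) -> str: return f"text:{path}"

LOADERS: dict[str, Callable[[str], str]] = {
    '.xlsx': load_excel,
    '.pdf': load_pdf,
}

def extract(file_path: str) -> str:
    suffix = Path(file_path).suffix.lower()
    # unknown extensions fall through to the plain-text loader, as above
    loader = LOADERS.get(suffix, load_text)
    return loader(file_path)

print(extract('report.pdf'))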

@@ -1,55 +0,0 @@
import logging
from typing import Optional
from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from extensions.ext_storage import storage
from models.model import UploadFile
logger = logging.getLogger(__name__)
class PdfLoader(BaseLoader):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
upload_file: Optional[UploadFile] = None
):
"""Initialize with file path."""
self._file_path = file_path
self._upload_file = upload_file
def load(self) -> list[Document]:
plaintext_file_key = ''
plaintext_file_exists = False
if self._upload_file:
if self._upload_file.hash:
plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
+ self._upload_file.hash + '.0625.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
pass
documents = PyPDFium2Loader(file_path=self._file_path).load()
text_list = []
for document in documents:
text_list.append(document.page_content)
text = "\n\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return documents

View File
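The removed PdfLoader caches extracted plaintext in storage under a key derived from the upload file's hash, so a PDF is only parsed once. A minimal sketch of that cache-aside pattern, with an in-memory dict standing in for the storage extension and parse_pdf standing in for PyPDFium2Loader:

_storage: dict[str, bytes] = {}  # stand-in for extensions.ext_storage.storage

def load_pdf_text(file_hash: str, parse_pdf) -> str:
    key = f"upload_files/{file_hash}.plaintext"
    cached = _storage.get(key)
    if cached is not None:
        return cached.decode('utf-8')   # cache hit: skip PDF parsing entirely
    text = parse_pdf()                  # cache miss: parse, then store for next time
    _storage[key] = text.encode('utf-8')
    return text

print(load_pdf_text('abc123', lambda: "page one\n\npage two"))
print(load_pdf_text('abc123', lambda: "never called"))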

@@ -1,12 +1,12 @@
from collections.abc import Sequence
from typing import Any, Optional, cast
from langchain.schema import Document
from sqlalchemy import func
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment

View File

@@ -3,12 +3,12 @@ import logging
from typing import Optional, cast
import numpy as np
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.entity.embedding import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper

View File

@@ -0,0 +1,8 @@
from enum import Enum
class PlanningStrategy(Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'

View File
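The new core.entities.agent_entities module pulls PlanningStrategy out of agent_executor so other modules can import it without the LangChain-based executor. Usage is the ordinary Enum pattern; the feature check mirrors the tool-call logic seen in the removed agent runner earlier in this diff, and pick_strategy is a hypothetical helper:

from enum import Enum

class PlanningStrategy(Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'

def pick_strategy(model_supports_tool_call: bool) -> PlanningStrategy:
    # default to ReAct; upgrade to function calling when the model supports it
    return PlanningStrategy.FUNCTION_CALL if model_supports_tool_call else PlanningStrategy.REACT

print(pick_strategy(True).value)   # function_call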

@@ -42,6 +42,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
"""
Advanced Completion Prompt Template Entity.
"""
class RolePrefixEntity(BaseModel):
"""
Role Prefix Entity.
@@ -57,6 +58,7 @@ class PromptTemplateEntity(BaseModel):
"""
Prompt Template Entity.
"""
class PromptType(Enum):
"""
Prompt Type.
@@ -97,6 +99,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
"""
Dataset Retrieve Config Entity.
"""
class RetrieveStrategy(Enum):
"""
Dataset Retrieve Strategy.
@@ -143,6 +146,15 @@ class SensitiveWordAvoidanceEntity(BaseModel):
config: dict[str, Any] = {}
class TextToSpeechEntity(BaseModel):
"""
Text To Speech Entity.
"""
enabled: bool
voice: Optional[str] = None
language: Optional[str] = None
class FileUploadEntity(BaseModel):
"""
File Upload Entity.
@@ -159,6 +171,7 @@ class AgentToolEntity(BaseModel):
tool_name: str
tool_parameters: dict[str, Any] = {}
class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
@@ -166,6 +179,7 @@ class AgentPromptEntity(BaseModel):
first_prompt: str
next_iteration: str
class AgentScratchpadUnit(BaseModel):
"""
Agent Scratchpad Unit.
@@ -182,12 +196,14 @@ class AgentScratchpadUnit(BaseModel):
thought: Optional[str] = None
action_str: Optional[str] = None
observation: Optional[str] = None
action: Optional[Action] = None
action: Optional[Action] = None
class AgentEntity(BaseModel):
"""
Agent Entity.
"""
class Strategy(Enum):
"""
Agent Strategy.
@@ -202,6 +218,7 @@ class AgentEntity(BaseModel):
tools: list[AgentToolEntity] = None
max_iteration: int = 5
class AppOrchestrationConfigEntity(BaseModel):
"""
App Orchestration Config Entity.
@@ -219,7 +236,7 @@ class AppOrchestrationConfigEntity(BaseModel):
show_retrieve_source: bool = False
more_like_this: bool = False
speech_to_text: bool = False
text_to_speech: bool = False
text_to_speech: dict = {}
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None

View File

@@ -1,199 +0,0 @@
import logging
from typing import Optional, cast
from langchain.tools import BaseTool
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (
AgentEntity,
AppOrchestrationConfigEntity,
InvokeFrom,
ModelConfigEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message
logger = logging.getLogger(__name__)
class AgentRunnerFeature:
def __init__(self, tenant_id: str,
app_orchestration_config: AppOrchestrationConfigEntity,
model_config: ModelConfigEntity,
config: AgentEntity,
queue_manager: ApplicationQueueManager,
message: Message,
user_id: str,
agent_llm_callback: AgentLLMCallback,
callback: AgentLoopGatherCallbackHandler,
memory: Optional[TokenBufferMemory] = None,) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param app_orchestration_config: app orchestration config
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
"""
self.tenant_id = tenant_id
self.app_orchestration_config = app_orchestration_config
self.model_config = model_config
self.config = config
self.queue_manager = queue_manager
self.message = message
self.user_id = user_id
self.agent_llm_callback = agent_llm_callback
self.callback = callback
self.memory = memory
def run(self, query: str,
invoke_from: InvokeFrom) -> Optional[str]:
"""
Retrieve agent loop result.
:param query: query
:param invoke_from: invoke from
:return:
"""
provider = self.config.provider
model = self.config.model
tool_configs = self.config.tools
# check model is support tool calling
provider_instance = model_provider_factory.get_provider_instance(provider=provider)
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model,
credentials=self.model_config.credentials
)
if not model_schema:
return None
planning_strategy = PlanningStrategy.REACT
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.FUNCTION_CALL
tools = self.to_tools(
tool_configs=tool_configs,
invoke_from=invoke_from,
callbacks=[self.callback, DifyStdOutCallbackHandler()],
)
if len(tools) == 0:
return None
agent_configuration = AgentConfiguration(
strategy=planning_strategy,
model_config=self.model_config,
tools=tools,
memory=self.memory,
max_iterations=10,
max_execution_time=400.0,
early_stopping_method="generate",
agent_llm_callback=self.agent_llm_callback,
callbacks=[self.callback, DifyStdOutCallbackHandler()]
)
agent_executor = AgentExecutor(agent_configuration)
try:
# check if should use agent
should_use_agent = agent_executor.should_use_agent(query)
if not should_use_agent:
return None
result = agent_executor.run(query)
return result.output
except Exception as ex:
logger.exception("agent_executor run failed")
return None
def to_dataset_retriever_tool(self, tool_config: dict,
invoke_from: InvokeFrom) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tool_config: tool config
:param invoke_from: invoke from
"""
show_retrieve_source = self.app_orchestration_config.show_retrieve_source
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=self.queue_manager,
app_id=self.message.app_id,
message_id=self.message.id,
user_id=self.user_id,
invoke_from=invoke_from
)
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
).first()
# pass if dataset is not available
if not dataset:
return None
# pass if dataset has no available documents
if dataset and dataset.available_document_count == 0:
return None
# get retrieval model config
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
# get score threshold
score_threshold = None
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=show_retrieve_source,
retriever_from=invoke_from.to_source()
)
return tool

View File
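The removed runner pulls top_k and an optional score threshold out of the dataset's retrieval_model dict, falling back to a default config when none is set. A small sketch of that extraction under the same dict shape; retrieval_params is a hypothetical helper:

DEFAULT_RETRIEVAL_MODEL = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'top_k': 2,
    'score_threshold_enabled': False,
}

def retrieval_params(retrieval_model: dict | None) -> tuple[int, float | None]:
    config = retrieval_model or DEFAULT_RETRIEVAL_MODEL
    # the threshold only applies when it is explicitly enabled
    score_threshold = config.get('score_threshold') if config.get('score_threshold_enabled') else None
    return config['top_k'], score_threshold

print(retrieval_params(None))
print(retrieval_params({'top_k': 4, 'score_threshold_enabled': True, 'score_threshold': 0.6}))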

@@ -1,13 +1,8 @@
import logging
from typing import Optional
from flask import current_app
from core.embedding.cached_embedding import CacheEmbedding
from core.entities.application_entities import InvokeFrom
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
@@ -45,17 +40,6 @@ class AnnotationReplyFeature:
embedding_provider_name = collection_binding_detail.provider_name
embedding_model_name = collection_binding_detail.model_name
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=app_record.tenant_id,
provider=embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=embedding_model_name
)
# get embedding model
embeddings = CacheEmbedding(model_instance)
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name,
embedding_model_name,
@@ -71,22 +55,14 @@ class AnnotationReplyFeature:
collection_binding_id=dataset_collection_binding.id
)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings,
attributes=['doc_id', 'annotation_id', 'app_id']
)
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
documents = vector_index.search(
documents = vector.search_by_vector(
query=query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 1,
'score_threshold': score_threshold,
'filter': {
'group_id': [dataset.id]
}
top_k=1,
score_threshold=score_threshold,
filter={
'group_id': [dataset.id]
}
)

View File
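The annotation reply feature now searches through the new rag Vector facade with explicit keyword arguments instead of a VectorIndex plus LangChain-style search_kwargs. A rough illustration of the new call shape; FakeVector is a hypothetical stand-in, since the real Vector needs a Dataset and database access:

class FakeVector:
    """Hypothetical stand-in for core.rag.datasource.vdb.vector_factory.Vector."""
    def __init__(self, docs: list[tuple[str, float]]):
        self._docs = docs

    def search_by_vector(self, query: str, top_k: int, score_threshold: float | None, filter: dict):
        # keep documents at or above the threshold, return the best top_k
        hits = [(text, score) for text, score in self._docs
                if score_threshold is None or score >= score_threshold]
        return hits[:top_k]

vector = FakeVector([('annotated answer', 0.92), ('weak match', 0.40)])
print(vector.search_by_vector(query='refund policy', top_k=1, score_threshold=0.8,
                              filter={'group_id': ['dataset-id']}))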

@@ -1,5 +1,6 @@
import json
import logging
import uuid
from datetime import datetime
from mimetypes import guess_extension
from typing import Optional, Union, cast
@@ -20,7 +21,14 @@ from core.file.message_file_parser import FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -77,7 +85,9 @@ class BaseAssistantApplicationRunner(AppRunner):
self.message = message
self.user_id = user_id
self.memory = memory
self.history_prompt_messages = prompt_messages
self.history_prompt_messages = self.organize_agent_history(
prompt_messages=prompt_messages or []
)
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance
@@ -504,17 +514,6 @@ class BaseAssistantApplicationRunner(AppRunner):
agent_thought.tool_labels_str = json.dumps(labels)
db.session.commit()
def get_history_prompt_messages(self) -> list[PromptMessage]:
"""
Get history prompt messages
"""
if self.history_prompt_messages is None:
self.history_prompt_messages = db.session.query(PromptMessage).filter(
PromptMessage.message_id == self.message.id,
).order_by(PromptMessage.position.asc()).all()
return self.history_prompt_messages
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
"""
@@ -589,4 +588,60 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
db_variables.updated_at = datetime.utcnow()
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
db.session.commit()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize agent history
"""
result = []
# check if there is a system message in the beginning of the conversation
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
result.append(prompt_messages[0])
messages: list[Message] = db.session.query(Message).filter(
Message.conversation_id == self.message.conversation_id,
).order_by(Message.created_at.asc()).all()
for message in messages:
result.append(UserPromptMessage(content=message.query))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(';')
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
tool_inputs = json.loads(agent_thought.tool_input)
for tool in tools:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(AssistantPromptMessage.ToolCall(
id=tool_call_id,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool,
arguments=json.dumps(tool_inputs.get(tool, {})),
)
))
tool_call_response.append(ToolPromptMessage(
content=agent_thought.observation,
name=tool,
tool_call_id=tool_call_id,
))
result.extend([
AssistantPromptMessage(
content=agent_thought.thought,
tool_calls=tool_calls,
),
*tool_call_response
])
if not tools:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:
result.append(AssistantPromptMessage(content=message.answer))
return result

View File
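The new organize_agent_history rebuilds past agent turns as assistant messages that carry tool calls, each followed by a matching tool message with the observation. A self-contained sketch of that shape using plain dicts instead of the real PromptMessage classes; thought_to_messages is a hypothetical helper:

import json
import uuid

def thought_to_messages(thought: str, tool: str, tool_input: dict, observation: str) -> list[dict]:
    # one assistant message with a tool call, followed by the tool's response,
    # tied together by a generated tool_call_id as in the diff above
    tool_call_id = str(uuid.uuid4())
    assistant = {
        'role': 'assistant',
        'content': thought,
        'tool_calls': [{'id': tool_call_id, 'type': 'function',
                        'function': {'name': tool, 'arguments': json.dumps(tool_input)}}],
    }
    tool_response = {'role': 'tool', 'name': tool, 'tool_call_id': tool_call_id,
                     'content': observation}
    return [assistant, tool_response]

print(thought_to_messages('I should check the weather', 'weather', {'city': 'Berlin'}, 'Sunny, 21C'))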

@@ -12,6 +12,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -39,6 +40,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
self._repack_app_orchestration_config(app_orchestration_config)
agent_scratchpad: list[AgentScratchpadUnit] = []
self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
# check model mode
if self.app_orchestration_config.model_config.mode == "completion":
@@ -128,64 +130,98 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
input=query
)
# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# recalc llm max tokens
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
llm_result: LLMResult = model_instance.invoke_llm(
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
tools=[],
stop=app_orchestration_config.model_config.stop,
stream=False,
stream=True,
user=self.user_id,
callbacks=[],
)
# check llm result
if not llm_result:
if not chunks:
raise ValueError("failed to invoke llm")
# get scratchpad
scratchpad = self._extract_response_scratchpad(llm_result.message.content)
agent_scratchpad.append(scratchpad)
# get llm usage
if llm_result.usage:
increase_usage(llm_usage, llm_result.usage)
usage_dict = {}
react_chunks = self._handle_stream_react(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response='',
thought='',
action_str='',
observation='',
action=None,
)
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
for chunk in react_chunks:
if isinstance(chunk, dict):
scratchpad.agent_response += json.dumps(chunk)
try:
if scratchpad.action:
raise Exception("")
scratchpad.action_str = json.dumps(chunk)
scratchpad.action = AgentScratchpadUnit.Action(
action_name=chunk['action'],
action_input=chunk['action_input']
)
except:
scratchpad.thought += json.dumps(chunk)
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint='',
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=json.dumps(chunk)
),
usage=None
)
)
else:
scratchpad.agent_response += chunk
scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint='',
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=chunk
),
usage=None
)
)
agent_scratchpad.append(scratchpad)
# get llm usage
if 'usage' in usage_dict:
increase_usage(llm_usage, usage_dict['usage'])
else:
usage_dict['usage'] = LLMUsage.empty_usage()
self.save_agent_thought(agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else '',
tool_input=scratchpad.action.action_input if scratchpad.action else '',
thought=scratchpad.thought,
observation='',
answer=llm_result.message.content,
answer=scratchpad.agent_response,
messages_ids=[],
llm_usage=llm_result.usage)
llm_usage=usage_dict['usage'])
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# publish agent thought if it's not empty and there is a action
if scratchpad.thought and scratchpad.action:
# check if final answer
if not scratchpad.action.action_name.lower() == "final answer":
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=scratchpad.thought
),
usage=llm_result.usage,
),
system_fingerprint=''
)
if not scratchpad.action:
# failed to extract action, return final answer directly
final_answer = scratchpad.agent_response or ''
@@ -260,7 +296,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# save scratchpad
scratchpad.observation = observation
scratchpad.agent_response = llm_result.message.content
# save agent thought
self.save_agent_thought(
@@ -269,7 +304,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
tool_input=tool_call_args,
thought=None,
observation=observation,
answer=llm_result.message.content,
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
@@ -316,6 +351,97 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
system_fingerprint=''
), PublishFrom.APPLICATION_MANAGER)
def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
-> Generator[Union[str, dict], None, None]:
def parse_json(json_str):
try:
return json.loads(json_str.strip())
except:
return json_str
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
if not code_blocks:
return
for block in code_blocks:
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
yield parse_json(json_text)
code_block_cache = ''
code_block_delimiter_count = 0
in_code_block = False
json_cache = ''
json_quote_count = 0
in_json = False
got_json = False
for response in llm_response:
response = response.delta.message.content
if not isinstance(response, str):
continue
# stream
index = 0
while index < len(response):
steps = 1
delta = response[index:index+steps]
if delta == '`':
code_block_cache += delta
code_block_delimiter_count += 1
else:
if not in_code_block:
if code_block_delimiter_count > 0:
yield code_block_cache
code_block_cache = ''
else:
code_block_cache += delta
code_block_delimiter_count = 0
if code_block_delimiter_count == 3:
if in_code_block:
yield from extra_json_from_code_block(code_block_cache)
code_block_cache = ''
in_code_block = not in_code_block
code_block_delimiter_count = 0
if not in_code_block:
# handle single json
if delta == '{':
json_quote_count += 1
in_json = True
json_cache += delta
elif delta == '}':
json_cache += delta
if json_quote_count > 0:
json_quote_count -= 1
if json_quote_count == 0:
in_json = False
got_json = True
index += steps
continue
else:
if in_json:
json_cache += delta
if got_json:
got_json = False
yield parse_json(json_cache)
json_cache = ''
json_quote_count = 0
in_json = False
if not in_code_block and not in_json:
yield delta.replace('`', '')
index += steps
if code_block_cache:
yield code_block_cache
if json_cache:
yield parse_json(json_cache)
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
"""
fill in inputs from external data tools
@@ -327,122 +453,40 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
continue
return instruction
def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
def _init_agent_scratchpad(self,
agent_scratchpad: list[AgentScratchpadUnit],
messages: list[PromptMessage]
) -> list[AgentScratchpadUnit]:
"""
extract response from llm response
init agent scratchpad
"""
def extra_quotes() -> AgentScratchpadUnit:
agent_response = content
# try to extract all quotes
pattern = re.compile(r'```(.*?)```', re.DOTALL)
quotes = pattern.findall(content)
# try to extract action from end to start
for i in range(len(quotes) - 1, 0, -1):
"""
1. use json load to parse action
2. use plain text `Action: xxx` to parse action
"""
try:
action = json.loads(quotes[i].replace('```', ''))
action_name = action.get("action")
action_input = action.get("action_input")
agent_thought = agent_response.replace(quotes[i], '')
if action_name and action_input:
return AgentScratchpadUnit(
agent_response=content,
thought=agent_thought,
action_str=quotes[i],
action=AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
current_scratchpad: AgentScratchpadUnit = None
for message in messages:
if isinstance(message, AssistantPromptMessage):
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content,
action_str='',
action=None,
observation=None,
)
if message.tool_calls:
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments)
)
except:
# try to parse action from plain text
action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
# delete action from agent response
agent_thought = agent_response.replace(quotes[i], '')
# remove extra quotes
agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
# remove Action: xxx from agent thought
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
if action_name and action_input:
return AgentScratchpadUnit(
agent_response=content,
thought=agent_thought,
action_str=quotes[i],
action=AgentScratchpadUnit.Action(
action_name=action_name[0],
action_input=action_input[0],
)
)
def extra_json():
agent_response = content
# try to extract all json
structures, pair_match_stack = [], []
started_at, end_at = 0, 0
for i in range(len(content)):
if content[i] == '{':
pair_match_stack.append(i)
if len(pair_match_stack) == 1:
started_at = i
elif content[i] == '}':
begin = pair_match_stack.pop()
if not pair_match_stack:
end_at = i + 1
structures.append((content[begin:i+1], (started_at, end_at)))
# handle the last character
if pair_match_stack:
end_at = len(content)
structures.append((content[pair_match_stack[0]:], (started_at, end_at)))
for i in range(len(structures), 0, -1):
try:
json_content, (started_at, end_at) = structures[i - 1]
action = json.loads(json_content)
action_name = action.get("action")
action_input = action.get("action_input")
# delete json content from agent response
agent_thought = agent_response[:started_at] + agent_response[end_at:]
# remove extra quotes like ```(json)*\n\n```
agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
# remove Action: xxx from agent thought
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
if action_name and action_input is not None:
return AgentScratchpadUnit(
agent_response=content,
thought=agent_thought,
action_str=json_content,
action=AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
)
except:
pass
agent_scratchpad = extra_quotes()
if agent_scratchpad:
return agent_scratchpad
agent_scratchpad = extra_json()
if agent_scratchpad:
return agent_scratchpad
return AgentScratchpadUnit(
agent_response=content,
thought=content,
action_str='',
action=None
)
except:
pass
agent_scratchpad.append(current_scratchpad)
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
current_scratchpad.observation = message.content
return agent_scratchpad
def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
agent_prompt_message: AgentPromptEntity,
):
@@ -556,15 +600,22 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# organize prompt messages
if mode == "chat":
# override system message
overrided = False
overridden = False
prompt_messages = prompt_messages.copy()
for prompt_message in prompt_messages:
if isinstance(prompt_message, SystemPromptMessage):
prompt_message.content = system_message
overrided = True
overridden = True
break
# convert tool prompt messages to user prompt messages
for idx, prompt_message in enumerate(prompt_messages):
if isinstance(prompt_message, ToolPromptMessage):
prompt_messages[idx] = UserPromptMessage(
content=prompt_message.content
)
if not overrided:
if not overridden:
prompt_messages.insert(0, SystemPromptMessage(
content=system_message,
))

View File
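_handle_stream_react above scans the streamed LLM output for fenced code blocks and bare JSON objects, yielding parsed dicts for tool actions and raw text for everything else. A simplified sketch of the same idea; unlike the real method it buffers the whole stream and works on plain strings rather than LLMResultChunk objects:

import json
import re
from collections.abc import Generator

def stream_react(chunks: list[str]) -> Generator[dict | str, None, None]:
    buffer = ''.join(chunks)
    last = 0
    # emit parsed actions from ```...``` blocks and pass other text through
    for match in re.finditer(r'```(?:json)?\s*(.*?)```', buffer, re.DOTALL):
        if buffer[last:match.start()].strip():
            yield buffer[last:match.start()]
        try:
            yield json.loads(match.group(1))
        except json.JSONDecodeError:
            yield match.group(1)
        last = match.end()
    if buffer[last:].strip():
        yield buffer[last:]

tokens = ['Thought: need a tool\n', '```json\n{"action": "search",', ' "action_input": "dify"}\n```']
for piece in stream_react(tokens):
    print(piece)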

@@ -105,8 +105,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
messages_ids=message_file_ids
)
# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# recalc llm max tokens
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,

View File

@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain):

View File

@@ -12,9 +12,9 @@ from pydantic import root_validator
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):

View File

@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.features.dataset_retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.

View File

@@ -1,4 +1,3 @@
import enum
import logging
from typing import Optional, Union
@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_config: ModelConfigEntity
@@ -62,28 +53,7 @@ class AgentExecutor:
self.agent = self._init_agent()
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None, # used for read chat histories memory
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
if self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]

View File

@@ -2,9 +2,10 @@ from typing import Optional, cast
from langchain.tools import BaseTool
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

View File

@@ -104,37 +104,17 @@ class HostingConfiguration:
if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
restrict_models=[
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
]
restrict_models=trial_models
)
quotas.append(trial_quota)
if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
paid_quota = PaidHostingQuota(
restrict_models=[
RestrictModel(model="gpt-4", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
]
restrict_models=paid_models
)
quotas.append(paid_quota)
@@ -258,3 +238,11 @@ class HostingConfiguration:
return HostedModerationConfig(
enabled=False
)
@staticmethod
def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
models_str = app_config.get(env_var)
models_list = models_str.split(",") if models_str else []
return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if
model_name.strip()]

View File
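HostingConfiguration now builds the hosted trial and paid model allow-lists from environment variables instead of hard-coded RestrictModel lists. The new helper is a split-and-strip over a comma-separated string; roughly, with a plain tuple standing in for RestrictModel:

def parse_restrict_models(models_str: str | None) -> list[tuple[str, str]]:
    # empty entries and stray whitespace are dropped, mirroring the helper above
    models_list = models_str.split(',') if models_str else []
    return [(name.strip(), 'llm') for name in models_list if name.strip()]

print(parse_restrict_models('gpt-3.5-turbo, gpt-4-turbo-preview, '))
# [('gpt-3.5-turbo', 'llm'), ('gpt-4-turbo-preview', 'llm')]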

@@ -1,51 +0,0 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset
class IndexBuilder:
@classmethod
def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
if indexing_technique == "high_quality":
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
return None
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
provider=dataset.embedding_model_provider,
model=dataset.embedding_model
)
embeddings = CacheEmbedding(embedding_model)
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
elif indexing_technique == "economy":
return KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=10
)
)
else:
raise ValueError('Unknown indexing technique')
@classmethod
def get_default_high_quality_index(cls, dataset: Dataset):
embeddings = OpenAIEmbeddings(openai_api_key=' ')
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)

View File

@@ -1,305 +0,0 @@
import json
import logging
from abc import abstractmethod
from typing import Any, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores import VectorStore
from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
class BaseVectorIndex(BaseIndex):
def __init__(self, dataset: Dataset, embeddings: Embeddings):
super().__init__(dataset)
self._embeddings = embeddings
self._vector_store = None
def get_type(self) -> str:
raise NotImplementedError
@abstractmethod
def get_index_name(self, dataset: Dataset) -> str:
raise NotImplementedError
@abstractmethod
def to_index_struct(self) -> dict:
raise NotImplementedError
@abstractmethod
def _get_vector_store(self) -> VectorStore:
raise NotImplementedError
@abstractmethod
def _get_vector_store_class(self) -> type:
raise NotImplementedError
@abstractmethod
def search_by_full_text_index(
self, query: str,
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
if search_type == 'similarity_score_threshold':
score_threshold = search_kwargs.get("score_threshold")
if (score_threshold is None) or (not isinstance(score_threshold, float)):
search_kwargs['score_threshold'] = .0
docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
query, **search_kwargs
)
docs = []
for doc, similarity in docs_with_similarity:
doc.metadata['score'] = similarity
docs.append(doc)
return docs
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.as_retriever(**kwargs)
def add_texts(self, texts: list[Document], **kwargs):
if self._is_origin():
self.recreate_dataset(self.dataset)
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
if kwargs.get('duplicate_check', False):
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
for node_id in ids:
vector_store.del_text(node_id)
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
if self.dataset.collection_binding_id:
vector_store.delete_by_group_id(group_id)
else:
vector_store.delete()
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
return False
def recreate_dataset(self, dataset: Dataset):
logging.info(f"Recreating dataset {dataset.id}")
try:
self.delete()
except Exception as e:
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
origin_index_struct = self.dataset.index_struct[:]
self.dataset.index_struct = None
if documents:
try:
self.create(documents)
except Exception as e:
self.dataset.index_struct = origin_index_struct
raise e
dataset.index_struct = json.dumps(self.to_index_struct())
db.session.commit()
self.dataset = dataset
logging.info(f"Dataset {dataset.id} recreate successfully.")
def create_qdrant_dataset(self, dataset: Dataset):
logging.info(f"create_qdrant_dataset {dataset.id}")
try:
self.delete()
except Exception as e:
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
self.create(documents)
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def update_qdrant_dataset(self, dataset: Dataset):
logging.info(f"update_qdrant_dataset {dataset.id}")
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).first()
if segment:
try:
exist = self.text_exists(segment.index_node_id)
if exist:
index_struct = {
"type": 'qdrant',
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"restore dataset in_one,_dataset {dataset.id}")
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
self.add_texts(documents)
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"delete original collection: {dataset.id}")
self.delete()
dataset.collection_binding_id = dataset_collection_binding.id
db.session.add(dataset)
db.session.commit()
logging.info(f"Dataset {dataset.id} recreate successfully.")

View File

@@ -1,165 +0,0 @@
from typing import Any, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from models.dataset import Dataset
class MilvusConfig(BaseModel):
host: str
port: int
user: str
password: str
secure: bool = False
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config MILVUS_HOST is required")
if not values['port']:
raise ValueError("config MILVUS_PORT is required")
if not values['user']:
raise ValueError("config MILVUS_USER is required")
if not values['password']:
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'host': self.host,
'port': self.port,
'user': self.user,
'password': self.password,
'secure': self.secure
}
class MilvusVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client_config = config
def get_type(self) -> str:
return 'milvus'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
self._vector_store = MilvusVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self.dataset),
connection_args=self._client_config.to_milvus_params(),
index_params=index_params
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = MilvusVectorStore.from_documents(
texts,
self._embeddings,
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content'
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
return MilvusVectorStore(
collection_name=self.get_index_name(self.dataset),
embedding_function=self._embeddings,
connection_args=self._client_config.to_milvus_params()
)
def _get_vector_store_class(self) -> type:
return MilvusVectorStore
def delete_by_document_id(self, document_id: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_document_id(document_id)
if ids:
vector_store.del_texts({
'filter': f'id in {ids}'
})
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_metadata_field(key, value)
if ids:
vector_store.del_texts({
'filter': f'id in {ids}'
})
def delete_by_ids(self, doc_ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_doc_ids(doc_ids)
vector_store.del_texts({
'filter': f' id in {ids}'
})
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
],
))
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []

View File

@@ -1,229 +0,0 @@
import os
from typing import Any, Optional, cast
import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
timeout: float = 20
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout
}
class QdrantVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client_config = config
def get_type(self) -> str:
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
return dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
else:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
return QdrantVectorStore(
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id'
)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="metadata.document_id",
match=models.MatchValue(value=document_id),
),
],
))
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
))
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
for node_id in ids:
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
))
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=group_id),
),
],
))
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
],
))
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
return vector_store.similarity_search_by_bm25(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
],
), kwargs.get('top_k', 2))
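
For orientation, a minimal usage sketch (not part of the diff) of how `QdrantConfig.to_qdrant_params()` above distinguishes an embedded, file-backed Qdrant from a remote server: an endpoint prefixed with `path:` is resolved against `root_path`, anything else is passed through as a URL together with the API key and timeout. The endpoint and key values below are illustrative assumptions.

```python
# Illustrative sketch only; assumes the Dify API environment where this module is importable.
from core.index.vector_index.qdrant_vector_index import QdrantConfig

# 'path:' endpoint -> embedded/local Qdrant, resolved against root_path
local = QdrantConfig(endpoint='path:storage/qdrant', root_path='/app/api')
print(local.to_qdrant_params())   # {'path': '/app/api/storage/qdrant'}

# plain URL endpoint -> remote Qdrant server, with api_key and timeout forwarded
remote = QdrantConfig(endpoint='http://qdrant:6333', api_key='your-api-key', timeout=20)
print(remote.to_qdrant_params())  # {'url': 'http://qdrant:6333', 'api_key': 'your-api-key', 'timeout': 20.0}
```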

View File

@@ -1,90 +0,0 @@
import json
from flask import current_app
from langchain.embeddings.base import Embeddings
from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document
class VectorIndex:
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self._dataset = dataset
self._embeddings = embeddings
self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
self._attributes = attributes
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
attributes: list) -> BaseVectorIndex:
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
if not vector_type:
raise ValueError("Vector store must be specified.")
if vector_type == "weaviate":
from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex
return WeaviateVectorIndex(
dataset=dataset,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
embeddings=embeddings,
attributes=attributes
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
return QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path,
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
),
embeddings=embeddings
)
elif vector_type == "milvus":
from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex
return MilvusVectorIndex(
dataset=dataset,
config=MilvusConfig(
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
),
embeddings=embeddings
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def add_texts(self, texts: list[Document], **kwargs):
if not self._dataset.index_struct_dict:
self._vector_index.create(texts, **kwargs)
self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
db.session.commit()
return
self._vector_index.add_texts(texts, **kwargs)
def __getattr__(self, name):
if self._vector_index is not None:
method = getattr(self._vector_index, name)
if callable(method):
return method
raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")

View File

@@ -1,179 +0,0 @@
from typing import Any, Optional, cast
import requests
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
weaviate.connect.connection.has_grpc = False
try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_type(self) -> str:
return 'weaviate'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = self._attributes
if self._is_origin():
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
)
def _get_vector_store_class(self) -> type:
return WeaviateVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": [key],
"valueText": value
})
def delete_by_group_id(self, group_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)

View File

@@ -9,21 +9,21 @@ from typing import Optional, cast
from flask import Flask, current_app
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import TextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError
from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatasetDocumentStore
from core.errors.error import ProviderTokenNotInitError
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from core.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@@ -31,7 +31,7 @@ from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from models.source import DataSourceBinding
from services.feature_service import FeatureService
class IndexingRunner:
@@ -56,38 +56,19 @@ class IndexingRunner:
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
# load file
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
# transform
documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
# save segment
self._load_segments(dataset, dataset_document, documents)
# get embedding model instance
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)
# split to documents
documents = self._step_split(
text_docs=text_docs,
splitter=splitter,
dataset=dataset,
dataset_document=dataset_document,
processing_rule=processing_rule
)
self._build_index(
# load
self._load(
index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
documents=documents
@@ -133,39 +114,19 @@ class IndexingRunner:
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
# load file
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
# get embedding model instance
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
# transform
documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
# save segment
self._load_segments(dataset, dataset_document, documents)
# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)
# split to documents
documents = self._step_split(
text_docs=text_docs,
splitter=splitter,
dataset=dataset,
dataset_document=dataset_document,
processing_rule=processing_rule
)
# build index
self._build_index(
# load
self._load(
index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
documents=documents
@@ -219,7 +180,15 @@ class IndexingRunner:
documents.append(document)
# build index
self._build_index(
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor()
self._load(
index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
documents=documents
@@ -238,12 +207,20 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing for the document.
"""
# check document limit
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
count = len(extract_settings)
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
embedding_model_instance = None
if dataset_id:
dataset = Dataset.query.filter_by(
@@ -275,16 +252,18 @@ class IndexingRunner:
total_segments = 0
total_price = 0
currency = 'USD'
for file_detail in file_details:
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
all_text_docs = []
for extract_setting in extract_settings:
# extract
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
all_text_docs.extend(text_docs)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# load data from file
text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')
# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)
@@ -296,7 +275,6 @@ class IndexingRunner:
)
total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
@@ -355,146 +333,8 @@ class IndexingRunner:
"preview": preview_texts
}
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model_instance = None
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
if indexing_technique == 'high_quality':
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
# load data from notion
tokens = 0
preview_texts = []
total_segments = 0
total_price = 0
currency = 'USD'
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
for page in notion_info['pages']:
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id,
notion_obj_id=page['page_id'],
notion_page_type=page['type']
)
documents = loader.load()
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)
# split to documents
documents = self._split_to_documents_for_estimate(
text_docs=documents,
splitter=splitter,
processing_rule=processing_rule
)
total_segments += len(documents)
embedding_model_type_instance = None
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' and embedding_model_type_instance:
tokens += embedding_model_type_instance.get_num_tokens(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
texts=[document.page_content]
)
if doc_form and doc_form == 'qa_model':
model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM
)
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
price_info = model_type_instance.get_price(
model=model_instance.model,
credentials=model_instance.credentials,
price_type=PriceType.INPUT,
tokens=total_segments * 2000,
)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(price_info.total_amount),
"currency": price_info.currency,
"qa_preview": document_qa_list,
"preview": preview_texts
}
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
embedding_price_info = embedding_model_type_instance.get_price(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
total_price = '{:f}'.format(embedding_price_info.total_amount)
currency = embedding_price_info.currency
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": total_price,
"currency": currency,
"preview": preview_texts
}
def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]:
def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
-> list[Document]:
# load file
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
return []
@@ -510,11 +350,28 @@ class IndexingRunner:
one_or_none()
if file_detail:
text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=file_detail,
document_model=dataset_document.doc_form
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()
if (not data_source_info or 'notion_workspace_id' not in data_source_info
or 'notion_page_id' not in data_source_info):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": data_source_info['notion_workspace_id'],
"notion_obj_id": data_source_info['notion_page_id'],
"notion_page_type": data_source_info['type'],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id
},
document_model=dataset_document.doc_form
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,
@@ -528,8 +385,6 @@ class IndexingRunner:
# replace doc id to document model id
text_docs = cast(list[Document], text_docs)
for text_doc in text_docs:
# remove invalid symbol
text_doc.page_content = self.filter_string(text_doc.page_content)
text_doc.metadata['document_id'] = dataset_document.id
text_doc.metadata['dataset_id'] = dataset_document.dataset_id
@@ -770,12 +625,12 @@ class IndexingRunner:
for q, a in matches if q and a
]
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None:
def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
dataset_document: DatasetDocument, documents: list[Document]) -> None:
"""
Build the index for the document.
insert index and update document/segment status to completed
"""
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
embedding_model_instance = self.model_manager.get_model_instance(
@@ -808,13 +663,9 @@ class IndexingRunner:
)
for document in chunk_documents
)
# save vector index
if vector_index:
vector_index.add_texts(chunk_documents)
# save keyword index
keyword_table_index.add_texts(chunk_documents)
# load index
index_processor.load(dataset, chunk_documents)
db.session.add(dataset)
document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
@@ -894,14 +745,64 @@ class IndexingRunner:
)
documents.append(document)
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts(documents, duplicate_check=True)
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents)
# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.add_texts(documents)
def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
text_docs: list[Document], process_rule: dict) -> list[Document]:
# get embedding model instance
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
process_rule=process_rule)
return documents
def _load_segments(self, dataset, dataset_document, documents):
# save node to document segment
doc_store = DatasetDocumentStore(
dataset=dataset,
user_id=dataset_document.created_by,
document_id=dataset_document.id
)
# add document segments
doc_store.add_documents(documents)
# update document status to indexing
cur_time = datetime.datetime.utcnow()
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="indexing",
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
}
)
# update segment status to indexing
self._update_segments_by_document(
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.utcnow()
}
)
pass
class DocumentIsPausedException(Exception):

View File

@@ -99,7 +99,8 @@ class ModelInstance:
user=user
)
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
@@ -166,13 +167,15 @@ class ModelInstance:
user=user
)
def invoke_tts(self, content_text: str, streaming: bool, user: Optional[str] = None) \
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
-> str:
"""
Invoke large language model
Invoke large language tts model
:param content_text: text content to be translated
:param tenant_id: user tenant id
:param user: unique user id
:param voice: model timbre
:param streaming: output is streaming
:return: text for given audio file
"""
@@ -185,9 +188,28 @@ class ModelInstance:
credentials=self.credentials,
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
streaming=streaming
)
def get_tts_voices(self, language: str) -> list:
"""
Get available voices of the TTS model
:param language: tts language
:return: tts model voices
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model,
credentials=self.credentials,
language=language
)
class ModelManager:
def __init__(self) -> None:
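
A minimal sketch (not part of the diff) of how the reworked TTS surface of `ModelInstance` might be used: `get_tts_voices` returns `{'name', 'value'}` pairs filtered by language, and the chosen `value` is passed to `invoke_tts` together with the new `tenant_id` argument. The language code and the fallback behaviour are assumptions.

```python
def synthesize(model_instance, tenant_id: str, text: str) -> str:
    """Sketch only: model_instance is assumed to be a ModelInstance bound to a TTS model."""
    voices = model_instance.get_tts_voices(language='en-US')   # [{'name': ..., 'value': ...}, ...]
    voice = voices[0]['value'] if voices else None             # fall back to the provider default (assumption)
    return model_instance.invoke_tts(
        content_text=text,
        tenant_id=tenant_id,
        voice=voice,
        streaming=False,
    )
```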

View File

@@ -20,7 +20,7 @@
![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png)
Displays the list of all supported providers. Besides the provider name and icon, it also returns the supported model types, the predefined model list, the configuration method, and the form rules for configuring credentials. For details of the rule design, see: [Schema](./schema.md).
Displays the list of all supported providers. Besides the provider name and icon, it also returns the supported model types, the predefined model list, the configuration method, and the form rules for configuring credentials. For details of the rule design, see: [Schema](./docs/zh_Hans/schema.md).
- Display of the selectable model list
@@ -86,4 +86,4 @@ Model Runtime has three layers:
![Alt text](docs/zh_Hans/images/index/image-2.png)
### [Concrete interface implementations 👈🏻](./docs/zh_Hans/interfaces.md)
Here you can find the concrete implementation of the interface you want to inspect, along with the exact meaning of its parameters and return values.
Here you can find the concrete implementation of the interface you want to inspect, along with the exact meaning of its parameters and return values.

View File

@@ -48,6 +48,10 @@
- `file_upload_limit` (int) Maximum file upload limit, in MB (available for model type `speech2text`)
- `supported_file_extensions` (string) Supported file extension formats, e.g. mp3, mp4 (available for model type `speech2text`)
- `default_voice` (string) Default voice, e.g. alloy, echo, fable, onyx, nova, shimmer (available for model type `tts`)
- `voices` (list) List of available voices, as in the sketch after this list (available for model type `tts`)
- `mode` (string) Voice identifier (available for model type `tts`)
- `name` (string) Voice display name (available for model type `tts`)
- `language` (string) Languages supported by the voice (available for model type `tts`)
- `word_limit` (int) Single-conversion word limit, split by paragraph by default (available for model type `tts`)
- `audio_type` (string) Supported audio file extension formats, e.g. mp3, wav (available for model type `tts`)
- `max_workers` (int) Number of concurrent workers for text-to-audio conversion (available for model type `tts`)
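
A hedged sketch of what a loaded `voices` property might look like and how it is filtered by language, mirroring `TTSModel.get_tts_model_voices()` later in this diff; the concrete voice names and language codes are illustrative only.

```python
# Illustrative data only; real entries come from the provider's model YAML.
voices = [
    {'mode': 'alloy', 'name': 'Alloy', 'language': ['en-US', 'zh-Hans']},
    {'mode': 'echo', 'name': 'Echo', 'language': ['en-US']},
]

language = 'zh-Hans'
# Same shape as the filter in TTSModel.get_tts_model_voices(): name for display, mode as the value.
selectable = [{'name': d['name'], 'value': d['mode']} for d in voices if language in d.get('language')]
print(selectable)   # [{'name': 'Alloy', 'value': 'alloy'}]
```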

View File

@@ -48,7 +48,11 @@
- `max_chunks` (int) Maximum number of chunks (available for model types `text-embedding` and `moderation`)
- `file_upload_limit` (int) Maximum file upload limit, in MB (available for model type `speech2text`)
- `supported_file_extensions` (string) Supported file extension formats, e.g. mp3, mp4 (available for model type `speech2text`)
- `default_voice` (string) Default voice, e.g. alloy, echo, fable, onyx, nova, shimmer (available for model type `tts`)
- `default_voice` (string) Default voice, e.g. alloy, echo, fable, onyx, nova, shimmer (available for model type `tts`)
- `voices` (list) List of selectable voices.
- `mode` (string) Voice identifier (available for model type `tts`)
- `name` (string) Voice display name (available for model type `tts`)
- `language` (string) Languages supported by the voice (available for model type `tts`)
- `word_limit` (int) Single-conversion word limit, split by paragraph by default (available for model type `tts`)
- `audio_type` (string) Supported audio file extension formats, e.g. mp3, wav (available for model type `tts`)
- `max_workers` (int) Number of concurrent workers for text-to-audio conversion (available for model type `tts`)

View File

@@ -81,5 +81,18 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
'min': 1,
'max': 2048,
'precision': 0,
},
DefaultParameterName.RESPONSE_FORMAT: {
'label': {
'en_US': 'Response Format',
'zh_Hans': '回复格式',
},
'type': 'string',
'help': {
'en_US': 'Set a response format to ensure the output from the LLM is a valid code block where possible, such as JSON, XML, etc.',
'zh_Hans': '设置一个返回格式确保llm的输出尽可能是有效的代码块如JSON、XML等',
},
'required': False,
'options': ['JSON', 'XML'],
}
}
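
For context, a hedged sketch of how the new `response_format` option is exercised at call time: a model YAML opts in with `use_template: response_format` (as the Claude YAMLs later in this diff do), and passing the parameter routes the invocation through `_code_block_mode_wrapper()`. The model name and surrounding setup are assumptions.

```python
def ask_for_json(llm, credentials: dict, prompt_messages: list):
    """Sketch only: llm is assumed to be a LargeLanguageModel subclass instance
    (e.g. the AnthropicLargeLanguageModel shown later in this diff)."""
    return llm.invoke(
        model='claude-3-sonnet-20240229',
        credentials=credentials,
        prompt_messages=prompt_messages,
        model_parameters={
            'max_tokens': 1024,
            'response_format': 'JSON',   # presence of this key triggers _code_block_mode_wrapper()
        },
        stream=False,
    )
```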

View File

@@ -91,6 +91,7 @@ class DefaultParameterName(Enum):
PRESENCE_PENALTY = "presence_penalty"
FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens"
RESPONSE_FORMAT = "response_format"
@classmethod
def value_of(cls, value: Any) -> 'DefaultParameterName':
@@ -127,6 +128,7 @@ class ModelPropertyKey(Enum):
SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
DEFAULT_VOICE = "default_voice"
VOICES = "voices"
WORD_LIMIT = "word_limit"
AUDOI_TYPE = "audio_type"
MAX_WORKERS = "max_workers"

View File

@@ -262,23 +262,23 @@ class AIModel(ABC):
try:
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
if not parameter_rule.max:
if not parameter_rule.max and 'max' in default_parameter_rule:
parameter_rule.max = default_parameter_rule['max']
if not parameter_rule.min:
if not parameter_rule.min and 'min' in default_parameter_rule:
parameter_rule.min = default_parameter_rule['min']
if not parameter_rule.precision:
if not parameter_rule.default and 'default' in default_parameter_rule:
parameter_rule.default = default_parameter_rule['default']
if not parameter_rule.precision:
if not parameter_rule.precision and 'precision' in default_parameter_rule:
parameter_rule.precision = default_parameter_rule['precision']
if not parameter_rule.required:
if not parameter_rule.required and 'required' in default_parameter_rule:
parameter_rule.required = default_parameter_rule['required']
if not parameter_rule.help:
if not parameter_rule.help and 'help' in default_parameter_rule:
parameter_rule.help = I18nObject(
en_US=default_parameter_rule['help']['en_US'],
)
if not parameter_rule.help.en_US:
if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
if not parameter_rule.help.zh_Hans:
if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
except ValueError:
pass

View File

@@ -9,7 +9,13 @@ from typing import Optional, Union
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.callbacks.logging_callback import LoggingCallback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
ModelPropertyKey,
ModelType,
@@ -74,7 +80,20 @@ class LargeLanguageModel(AIModel):
)
try:
result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
if "response_format" in model_parameters:
result = self._code_block_mode_wrapper(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
)
else:
result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
except Exception as e:
self._trigger_invoke_error_callbacks(
model=model,
@@ -120,6 +139,239 @@ class LargeLanguageModel(AIModel):
return result
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper, ensure the response is a code block with output markdown quote
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
block_prompts = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can find in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
code_block = model_parameters.get("response_format", "")
if not code_block:
return self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
model_parameters.pop("response_format")
stop = stop or []
stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block)
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", prompt_messages[0].content)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", f"Please output a valid {code_block} object.")
))
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# add ```JSON\n to the last message
prompt_messages[-1].content += f"\n```{code_block}\n"
else:
# append a user message
prompt_messages.append(UserPromptMessage(
content=f"```{code_block}\n"
))
response = self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
if isinstance(response, Generator):
first_chunk = next(response)
def new_generator():
yield first_chunk
yield from response
if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
return self._code_block_mode_stream_processor_with_backtick(
model=model,
prompt_messages=prompt_messages,
input_generator=new_generator()
)
else:
return self._code_block_mode_stream_processor(
model=model,
prompt_messages=prompt_messages,
input_generator=new_generator()
)
return response
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
input_generator: Generator[LLMResultChunk, None, None]
) -> Generator[LLMResultChunk, None, None]:
"""
Code block mode stream processor, ensure the response is a code block with output markdown quote
:param model: model name
:param prompt_messages: prompt messages
:param input_generator: input generator
:return: output generator
"""
state = "normal"
backtick_count = 0
for piece in input_generator:
if piece.delta.message.content:
content = piece.delta.message.content
piece.delta.message.content = ""
yield piece
piece = content
else:
yield piece
continue
new_piece = ""
for char in piece:
if state == "normal":
if char == "`":
state = "in_backticks"
backtick_count = 1
else:
new_piece += char
elif state == "in_backticks":
if char == "`":
backtick_count += 1
if backtick_count == 3:
state = "skip_content"
backtick_count = 0
else:
new_piece += "`" * backtick_count + char
state = "normal"
backtick_count = 0
elif state == "skip_content":
if char.isspace():
state = "normal"
if new_piece:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=new_piece,
tool_calls=[]
),
)
)
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
input_generator: Generator[LLMResultChunk, None, None]) \
-> Generator[LLMResultChunk, None, None]:
"""
Code block mode stream processor, ensure the response is a code block with output markdown quote.
This version skips the language identifier that follows the opening triple backticks.
:param model: model name
:param prompt_messages: prompt messages
:param input_generator: input generator
:return: output generator
"""
state = "search_start"
backtick_count = 0
for piece in input_generator:
if piece.delta.message.content:
content = piece.delta.message.content
# Reset content to ensure we're only processing and yielding the relevant parts
piece.delta.message.content = ""
# Yield a piece with cleared content before processing it to maintain the generator structure
yield piece
piece = content
else:
# Yield pieces without content directly
yield piece
continue
if state == "done":
continue
new_piece = ""
for char in piece:
if state == "search_start":
if char == "`":
backtick_count += 1
if backtick_count == 3:
state = "skip_language"
backtick_count = 0
else:
backtick_count = 0
elif state == "skip_language":
# Skip everything until the first newline, marking the end of the language identifier
if char == "\n":
state = "in_code_block"
elif state == "in_code_block":
if char == "`":
backtick_count += 1
if backtick_count == 3:
state = "done"
break
else:
if backtick_count > 0:
# If backticks were counted but we're still collecting content, it was a false start
new_piece += "`" * backtick_count
backtick_count = 0
new_piece += char
elif state == "done":
break
if new_piece:
# Only yield content collected within the code block
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=new_piece,
tool_calls=[]
),
)
)
def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
@@ -204,7 +456,7 @@ class LargeLanguageModel(AIModel):
:return: full response or stream response chunk generator result
"""
raise NotImplementedError
@abstractmethod
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:

View File

@@ -15,29 +15,37 @@ class TTSModel(AIModel):
"""
model_type: ModelType = ModelType.TTS
def invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
user: Optional[str] = None):
"""
Invoke large language model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:param streaming: output is streaming
:param user: unique user id
:return: translated audio file
"""
try:
return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming, content_text=content_text)
self._is_ffmpeg_installed()
return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming,
content_text=content_text, voice=voice, tenant_id=tenant_id)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
user: Optional[str] = None):
"""
Invoke large language model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:param streaming: output is streaming
:param user: unique user id
@@ -45,7 +53,25 @@ class TTSModel(AIModel):
"""
raise NotImplementedError
def _get_model_voice(self, model: str, credentials: dict) -> any:
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
"""
Get available voices for the given TTS model
:param language: tts language
:param model: model name
:param credentials: model credentials
:return: voices lists
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
if language:
return [{'name': d['name'], 'value': d['mode']} for d in voices if language and language in d.get('language')]
else:
return [{'name': d['name'], 'value': d['mode']} for d in voices]
def _get_model_default_voice(self, model: str, credentials: dict) -> any:
"""
Get voice for given tts model

View File

@@ -6,6 +6,7 @@
- bedrock
- togetherai
- ollama
- mistralai
- replicate
- huggingface_hub
- zhipuai

View File

@@ -21,7 +21,7 @@ class AnthropicProvider(ModelProvider):
# Use `claude-instant-1` model for validate,
model_instance.validate_credentials(
model='claude-instant-1',
model='claude-instant-1.2',
credentials=credentials
)
except CredentialsValidateFailedError as ex:

View File

@@ -2,8 +2,8 @@ provider: anthropic
label:
en_US: Anthropic
description:
en_US: Anthropic's powerful models, such as Claude 2 and Claude Instant.
zh_Hans: Anthropic 的强大模型,例如 Claude 2 和 Claude Instant。
en_US: Anthropic's powerful models, such as Claude 3.
zh_Hans: Anthropic 的强大模型,例如 Claude 3。
icon_small:
en_US: icon_s_en.svg
icon_large:

View File

@@ -0,0 +1,6 @@
- claude-3-opus-20240229
- claude-3-sonnet-20240229
- claude-2.1
- claude-instant-1.2
- claude-2
- claude-instant-1

View File

@@ -27,6 +27,8 @@ parameter_rules:
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '8.00'
output: '24.00'

View File

@@ -27,8 +27,11 @@ parameter_rules:
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '8.00'
output: '24.00'
unit: '0.000001'
currency: USD
deprecated: true

View File

@@ -0,0 +1,37 @@
model: claude-3-opus-20240229
label:
en_US: claude-3-opus-20240229
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '15.00'
output: '75.00'
unit: '0.000001'
currency: USD

View File

@@ -0,0 +1,37 @@
model: claude-3-sonnet-20240229
label:
en_US: claude-3-sonnet-20240229
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '3.00'
output: '15.00'
unit: '0.000001'
currency: USD

View File

@@ -0,0 +1,35 @@
model: claude-instant-1.2
label:
en_US: claude-instant-1.2
model_type: llm
features: [ ]
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '1.63'
output: '5.51'
unit: '0.000001'
currency: USD

View File

@@ -26,8 +26,11 @@ parameter_rules:
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '1.63'
output: '5.51'
unit: '0.000001'
currency: USD
deprecated: true

View File

@@ -1,17 +1,32 @@
import base64
import mimetypes
from collections.abc import Generator
from typing import Optional, Union
from typing import Optional, Union, cast
import anthropic
import requests
from anthropic import Anthropic, Stream
from anthropic.types import Completion, completion_create_params
from anthropic.types import (
ContentBlockDeltaEvent,
Message,
MessageDeltaEvent,
MessageStartEvent,
MessageStopEvent,
MessageStreamEvent,
completion_create_params,
)
from httpx import Timeout
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
@@ -25,9 +40,17 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can find in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
class AnthropicLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
@@ -47,7 +70,114 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke llm chat model
:param model: model name
:param credentials: credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
# transform model parameters from completion api of anthropic to chat api
if 'max_tokens_to_sample' in model_parameters:
model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample')
# init model client
client = Anthropic(**credentials_kwargs)
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop
if user:
extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages)
if system:
extra_model_kwargs['system'] = system
# chat model
response = client.messages.create(
model=model,
messages=prompt_message_dicts,
stream=stream,
**model_parameters,
**extra_model_kwargs
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if 'response_format' in model_parameters and model_parameters['response_format']:
stop = stop or []
# chat model
self._transform_chat_json_prompts(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
response_format=model_parameters['response_format']
)
model_parameters.pop('response_format')
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _transform_chat_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -74,7 +204,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return:
"""
try:
self._generate(
self._chat_generate(
model=model,
credentials=credentials,
prompt_messages=[
@@ -82,58 +212,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
],
model_parameters={
"temperature": 0,
"max_tokens_to_sample": 20,
"max_tokens": 20,
},
stream=False
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Invoke large language model
:param model: model name
:param credentials: credentials kwargs
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
client = Anthropic(**credentials_kwargs)
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop
if user:
extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
response = client.completions.create(
model=model,
prompt=self._convert_messages_to_prompt_anthropic(prompt_messages),
stream=stream,
**model_parameters,
**extra_model_kwargs
)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Handle llm response
Handle llm chat response
:param model: model name
:param credentials: credentials
@@ -143,75 +232,89 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.completion
content=response.content[0].text
)
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
if response.usage:
# transform usage
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
response = LLMResult(
model=response.model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
usage=usage
)
return result
return response
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
response: Stream[MessageStreamEvent],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
Handle llm chat stream response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator result
:return: llm response chunk generator
"""
index = -1
full_assistant_content = ''
return_model = None
input_tokens = 0
output_tokens = 0
finish_reason = None
index = 0
for chunk in response:
content = chunk.completion
if chunk.stop_reason is None and (content is None or content == ''):
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=content if content else '',
)
index += 1
if chunk.stop_reason is not None:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
if isinstance(chunk, MessageStartEvent):
return_model = chunk.message.model
input_tokens = chunk.message.usage.input_tokens
elif isinstance(chunk, MessageDeltaEvent):
output_tokens = chunk.usage.output_tokens
finish_reason = chunk.delta.stop_reason
elif isinstance(chunk, MessageStopEvent):
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
yield LLMResultChunk(
model=chunk.model,
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
finish_reason=chunk.stop_reason,
index=index + 1,
message=AssistantPromptMessage(
content=''
),
finish_reason=finish_reason,
usage=usage
)
)
else:
elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text if chunk.delta.text else ''
full_assistant_content += chunk_text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text
)
index = chunk.index
yield LLMResultChunk(
model=chunk.model,
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
index=chunk.index,
message=assistant_prompt_message,
)
)
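The handler above consumes the raw server-sent events of the Messages API. The same event types can be exercised directly with the anthropic SDK; a minimal self-contained sketch (placeholder key and model, limited to the event types used above):

from anthropic import Anthropic
from anthropic.types import (ContentBlockDeltaEvent, MessageDeltaEvent,
                             MessageStartEvent, MessageStopEvent)

client = Anthropic(api_key="sk-ant-...")  # placeholder credential
stream = client.messages.create(
    model="claude-3-sonnet-20240229",
    messages=[{"role": "user", "content": "Say hi"}],
    max_tokens=256,
    stream=True,
)

text, input_tokens, output_tokens, stop_reason = "", 0, 0, None
for event in stream:
    if isinstance(event, MessageStartEvent):
        input_tokens = event.message.usage.input_tokens   # prompt usage arrives with message_start
    elif isinstance(event, ContentBlockDeltaEvent):
        text += event.delta.text or ""                     # incremental text deltas
    elif isinstance(event, MessageDeltaEvent):
        output_tokens = event.usage.output_tokens          # completion usage and stop reason
        stop_reason = event.delta.stop_reason
    elif isinstance(event, MessageStopEvent):
        break

print(text, input_tokens, output_tokens, stop_reason)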
@@ -234,6 +337,80 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return credentials_kwargs
def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
"""
Convert prompt messages to dict list and system
"""
system = ""
prompt_message_dicts = []
for message in prompt_messages:
if isinstance(message, SystemPromptMessage):
system += message.content + ("\n" if not system else "")
else:
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
return system, prompt_message_dicts
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
raise ValueError(f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp")
sub_message_dict = {
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_data
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
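The image branch above builds Messages API vision blocks. For reference, a user turn with one base64 image and one text block, as assembled by this converter, has roughly this shape (placeholder values):

vision_message = {
    "role": "user",
    "content": [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/png",              # one of the four types checked above
                "data": "<base64-encoded image bytes>",
            },
        },
        {"type": "text", "text": "What is in this image?"},
    ],
}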
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
"""
Convert a single message to a string.


@@ -128,8 +128,10 @@ class BaichuanModel:
'role': message.role,
})
# [baichuan] frequency_penalty must be between 1 and 2
if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2:
parameters['frequency_penalty'] = 1
if 'frequency_penalty' in parameters:
if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2:
parameters['frequency_penalty'] = 1
# turbo api accepts flat parameters
return {
'model': self._model_mapping(model),
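The hunk above makes the clamp conditional: frequency_penalty is only normalized into Baichuan's accepted [1, 2] range when the caller actually supplied it, so requests without that parameter no longer fail on the dict lookup. A small sketch of the guarded behavior (hypothetical parameters):

parameters = {"temperature": 0.3}                 # no frequency_penalty supplied
if "frequency_penalty" in parameters:             # guard added by this change
    if parameters["frequency_penalty"] < 1 or parameters["frequency_penalty"] > 2:
        parameters["frequency_penalty"] = 1       # clamp out-of-range values to 1
# parameters is left untouched when the key is absent; the old code raised KeyError here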


@@ -103,7 +103,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
], parameters={
'max_tokens': 1,
}, timeout=60)
except (InvalidAPIKeyError, InvalidAuthenticationError) as e:
except Exception as e:
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],


@@ -27,6 +27,8 @@ parameter_rules:
default: 2048
min: 1
max: 2048
- name: response_format
use_template: response_format
pricing:
input: '0.00'
output: '0.00'


@@ -31,6 +31,16 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
logger = logging.getLogger(__name__)
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
@@ -53,7 +63,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""


@@ -2,7 +2,7 @@ provider: jina
label:
en_US: Jina
description:
en_US: Embedding Model Supported
en_US: Embedding and Rerank Model Supported
icon_small:
en_US: icon_s_en.svg
icon_large:
@@ -13,9 +13,10 @@ help:
en_US: Get your API key from Jina AI
zh_Hans: 从 Jina 获取 API Key
url:
en_US: https://jina.ai/embeddings/
en_US: https://jina.ai/
supported_model_types:
- text-embedding
- rerank
configurate_methods:
- predefined-model
provider_credential_schema: