Compare commits

...

73 Commits

Author SHA1 Message Date
zxhlyh
d70d61b1cb frontend for model runtime (#1861)
Co-authored-by: Joel <iamjoel007@gmail.com>
2024-01-03 00:05:08 +08:00
takatost
d069c668f8 Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
2024-01-02 23:42:00 +08:00
Jyong
e91dd28a76 fix file estimate issue (#1860)
Co-authored-by: jyong <jyong@dify.ai>
2024-01-02 16:25:59 +08:00
Jyong
595e9b25ba Add data clean schedule (#1859)
Co-authored-by: jyong <jyong@dify.ai>
2024-01-02 15:29:18 +08:00
waltcow
06d2d8cea3 Refactor BaseVectorIndex delete method (#1853) 2023-12-30 21:49:01 +08:00
Bowen Liang
936c3cc4d7 ci: Bump Docker Github actions (#1852) 2023-12-30 10:58:28 +08:00
crazywoola
08abbb8dba Feat/add community link to dropdown (#1851) 2023-12-28 18:14:16 +08:00
Joel
972cf3cd01 fix: splitting text ui broken (#1848) 2023-12-27 17:59:50 +08:00
Jyong
da4847c5a8 fix segment update issue (#1844)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-26 16:22:51 +08:00
Jyong
08494058e9 fix file type not support when preview (#1841)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-26 15:06:44 +08:00
takatost
9080ece3fb feat: comment db port to host (#1831) 2023-12-24 15:43:43 +08:00
takatost
438912700c feat: nginx add "restart: always" (#1829) 2023-12-24 15:33:28 +08:00
Charlie.Wei
6b57e4e0ff Fix chitchat lost context (#1828)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2023-12-23 10:05:17 +08:00
Joel
6da3a33e6c fix: selection too long break ui (#1826) 2023-12-22 16:54:18 +08:00
crazywoola
2c8badfea9 Update README.md (#1825) 2023-12-22 15:05:11 +08:00
crazywoola
91182a86bf fix: edited by is missing (#1824) 2023-12-22 14:20:11 +08:00
Bannings
0b7e0cadc0 Fix Azure OpenAI Provider BASE_MODELS (#1813)
Co-authored-by: xifan <xifan@gaoding.com>
2023-12-22 14:17:15 +08:00
Jyong
163515c6e9 fix unstructured requirements (#1821)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-22 10:31:06 +08:00
Joel
40d612ffc7 feat: add roadmap and feedback link (#1816) 2023-12-21 16:17:40 +08:00
Chenhe Gu
88a73ecdea add link to canny to README, plus some rewording (#1814) 2023-12-21 02:04:40 -06:00
Charlie.Wei
64642fabc4 Parse base64 eml file (#1796)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2023-12-21 13:18:58 +08:00
crazywoola
7083a05a25 fix: mail link color (#1812) 2023-12-21 12:44:08 +08:00
Charlie.Wei
9f3ed32d0f Fix azure openai gpt4v&1106 config (#1811)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2023-12-21 12:37:38 +08:00
crazywoola
1521ac5563 feat: add email template for invite new user in workspace (#1810) 2023-12-21 11:09:41 +08:00
Joshua
695246d80a Update README_CN.md (#1804) 2023-12-20 22:43:57 +08:00
Joshua
908164f6d5 Update README_CN.md (#1806) 2023-12-20 22:43:18 +08:00
Joshua
96206b6108 Update README_JA.md (#1805) 2023-12-20 22:42:43 +08:00
Joshua
e2fff7fd87 Update README_ES.md (#1807) 2023-12-20 22:41:26 +08:00
Joshua
53690bfad2 Update README.md (#1803) 2023-12-20 21:50:16 +08:00
Joshua
ae37a7d998 Update README.md (#1802) 2023-12-20 21:48:40 +08:00
Garfield Dai
7b37e05dec feat: add billing switch. (#1789)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-12-20 15:37:57 +08:00
Jyong
022450768f fix gpt 4v upload image issue (#1799)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-20 13:03:08 +08:00
crazywoola
7c5661152e fix: settings/members dropdown ui (#1797) 2023-12-20 09:27:22 +08:00
crazywoola
fb55b3a89a Fix: delete member dropdown not shown (#1794) 2023-12-19 20:01:58 +08:00
Jyong
df1509983c ppt & pptx improve (#1790)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-19 18:11:27 +08:00
Jyong
185c2f86cd Compatible with the situation where there is no user information. (#1792)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-19 17:47:25 +08:00
taokuizu
10fc44e2af fix typo (#1791) 2023-12-19 17:23:49 +08:00
Joel
c3275dfd36 fix: not return annotation author error happens (#1793) 2023-12-19 17:22:54 +08:00
takatost
43741ad5d1 feat: bump version to 0.3.34 (#1788) 2023-12-19 14:08:47 +08:00
Joel
8dec406161 chore: enchance ext name (#1787) 2023-12-19 14:03:24 +08:00
zxhlyh
58f8d74591 fix: unstructured file extension (#1785) 2023-12-19 12:09:48 +08:00
zxhlyh
867fc61b12 fix: web app text (#1784) 2023-12-19 11:45:16 +08:00
Joel
8e2e477a7f chore: enchance annotation ui (#1781) 2023-12-19 10:25:54 +08:00
Joel
9b34f5a9ff feat: unstructured frontend (#1777) 2023-12-18 23:28:25 +08:00
Jyong
5e34f938c1 Feat/add unstructured support (#1780)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 23:24:06 +08:00
Jyong
2fd56cb01c Fix/vdb index issue (#1776)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 21:33:54 +08:00
Joel
4f0e272549 fix: add then eidt annotion cause show bug (#1775) 2023-12-18 19:33:48 +08:00
zxhlyh
1a5279a3ef fix: get billing info in self-hosted edition from current workspace (#1774) 2023-12-18 17:54:16 +08:00
Joel
7775f5785f chore: update annotation reply english i18n (#1773) 2023-12-18 17:13:15 +08:00
Garfield Dai
2de73991ff feat: only tenant owner can subscription. (#1770) 2023-12-18 16:59:31 +08:00
Joel
354d033e60 fix: not owner can not pay (#1772) 2023-12-18 16:54:47 +08:00
Jyong
ebc2cdad2e fix annotation query exception (#1771)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 16:48:34 +08:00
zxhlyh
5bb841935e feat: custom webapp logo (#1766) 2023-12-18 16:25:37 +08:00
Joel
65fd4b39ce feat: annotation management frontend (#1764) 2023-12-18 15:41:24 +08:00
Jyong
96d2de2258 fix annotation reply in universal chat (#1768)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 15:04:17 +08:00
Jyong
a71f2863ac Annotation management (#1767)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-18 13:10:05 +08:00
crazywoola
a9b942981d fix: issue templates not render correctly (#1763) 2023-12-18 09:22:11 +08:00
Garfield Dai
4b1ba2ec21 feat: remove billing config. (#1761) 2023-12-17 13:22:45 +08:00
Qiwen Tong
c09184fd94 update bm25 search properties (#1758)
Co-authored-by: Blade <zhangxiaobin@unixyz.cn>
2023-12-15 12:28:03 +08:00
Charlie.Wei
b0d8d196e1 azure openai add gpt-4-1106-preview、gpt-4-vision-preview models (#1751)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2023-12-14 09:55:30 +08:00
Garfield Dai
7c43123956 feat: can replace logo. (#1752) 2023-12-13 20:21:39 +08:00
zxhlyh
eede84eb9e feat: web app support some feature (#1753) 2023-12-13 20:21:11 +08:00
Joel
b5b20234e9 feat: update pricing (#1749) 2023-12-13 16:41:40 +08:00
Joel
5beb298e47 chore: update term links (#1748) 2023-12-13 15:12:27 +08:00
Garfield Dai
6b499b9a16 remove stripe and anthropic. (#1746) 2023-12-12 17:59:07 +08:00
Chenhe Gu
4c639961f5 add self checks to issues and discussions templates (#1742) 2023-12-11 23:25:17 +08:00
Joel
dfd3f507fb fix: ad block disabled tracking would block ga then can not pay (#1741) 2023-12-11 16:36:58 +08:00
Jyong
d5695b3170 check rerank document is not empty (#1740)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-11 16:12:11 +08:00
crazywoola
994fceece3 fix: qa regex (#1738) 2023-12-11 15:53:37 +08:00
Jyong
8c451eb0e6 fix only full text search in app issue (#1736)
Co-authored-by: jyong <jyong@dify.ai>
2023-12-11 15:34:29 +08:00
Joel
79b4366203 fix: server component use translate errorts lint error (#1732) 2023-12-11 10:15:49 +08:00
Joel
3675d2eae8 fix: prompt null parse var error (#1731) 2023-12-11 10:06:01 +08:00
crazywoola
38b55d2186 fix: default types (#1728) 2023-12-09 23:38:07 +08:00
1042 changed files with 180794 additions and 24661 deletions

View File

@@ -1,11 +1,18 @@
name: "🕷️ Bug report"
description: Report errors or unexpected behavior [please use English :]
description: Report errors or unexpected behavior
labels:
- bug
body:
- type: markdown
- type: checkboxes
attributes:
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: input
attributes:
label: Dify version

View File

@@ -1,8 +1,16 @@
name: "📚 Documentation Issue"
description: Report issues in our documentation [please use English :]
description: Report issues in our documentation
labels:
- ducumentation
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: textarea
attributes:
label: Provide a description of requested docs changes

View File

@@ -1,8 +1,17 @@
name: "⭐ Feature or enhancement request"
description: Propose something new. [please use English :]
description: Propose something new.
labels:
- enhancement
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: textarea
attributes:
label: Description of the new feature / enhancement

View File

@@ -3,6 +3,15 @@ description: "Request help from the community" [please use English :]
labels:
- help-wanted
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: textarea
attributes:
label: Provide a description of the help you need

View File

@@ -3,9 +3,15 @@ description: Report incorrect translations. [please use English :]
labels:
- translation
body:
- type: markdown
- type: checkboxes
attributes:
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: input
attributes:
label: Dify version

View File

@@ -0,0 +1,58 @@
name: Run Pytest
on:
pull_request:
branches:
- main
push:
branches:
- deploy/dev
- feat/model-runtime
jobs:
test:
runs-on: ubuntu-latest
env:
OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
AZURE_OPENAI_API_BASE: https://difyai-openai.openai.azure.com
AZURE_OPENAI_API_KEY: xxxxb1707exxxxxxxxxxaaxxxxxf94
ANTHROPIC_API_KEY: sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
CHATGLM_API_BASE: http://a.abc.com:11451
XINFERENCE_SERVER_URL: http://a.abc.com:11451
XINFERENCE_GENERATION_MODEL_UID: generate
XINFERENCE_CHAT_MODEL_UID: chat
XINFERENCE_EMBEDDINGS_MODEL_UID: embedding
XINFERENCE_RERANK_MODEL_UID: rerank
GOOGLE_API_KEY: abcdefghijklmnopqrstuvwxyz
HUGGINGFACE_API_KEY: hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL: a
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL: b
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
MOCK_SWITCH: true
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.10'
- name: Cache pip dependencies
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
restore-keys: ${{ runner.os }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install -r api/requirements.txt
- name: Run pytest
run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py

View File

@@ -1,38 +0,0 @@
name: Run Pytest
on:
pull_request:
branches:
- main
push:
branches:
- deploy/dev
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.10'
- name: Cache pip dependencies
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
restore-keys: ${{ runner.os }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install -r api/requirements.txt
- name: Run pytest
run: pytest api/tests/unit_tests

View File

@@ -14,10 +14,10 @@ jobs:
if: github.event.pull_request.draft == false
steps:
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v2
@@ -27,7 +27,7 @@ jobs:
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
uses: docker/metadata-action@v5
with:
images: langgenius/dify-api
tags: |
@@ -37,7 +37,7 @@ jobs:
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
- name: Build and push
uses: docker/build-push-action@v4
uses: docker/build-push-action@v5
with:
context: "{{defaultContext}}:api"
platforms: ${{ startsWith(github.ref, 'refs/tags/') && 'linux/amd64,linux/arm64' || 'linux/amd64' }}

View File

@@ -14,10 +14,10 @@ jobs:
if: github.event.pull_request.draft == false
steps:
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v2
@@ -27,7 +27,7 @@ jobs:
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
uses: docker/metadata-action@v5
with:
images: langgenius/dify-web
tags: |
@@ -37,7 +37,7 @@ jobs:
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
- name: Build and push
uses: docker/build-push-action@v4
uses: docker/build-push-action@v5
with:
context: "{{defaultContext}}:web"
platforms: ${{ startsWith(github.ref, 'refs/tags/') && 'linux/amd64,linux/arm64' || 'linux/amd64' }}

View File

@@ -55,6 +55,11 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!
### Provider Integrations
If you see a model provider not yet supported by Dify that you'd like to use, follow these [steps](api/core/model_runtime/README.md) to submit a PR.
### i18n (Internationalization) Support
We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.

View File

@@ -4,7 +4,8 @@
<a href="./README_CN.md">简体中文</a> |
<a href="./README_JA.md">日本語</a> |
<a href="./README_ES.md">Español</a> |
<a href="./README_KL.md">Klingon</a>
<a href="./README_KL.md">Klingon</a> |
<a href="./README_FR.md">Français</a>
</p>
<p align="center">
@@ -20,19 +21,18 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>
[v0.3.31:Surpassing the Assistants API Dify's RAG Demonstrates an Impressive 20% Improvement.](https://dify.ai/blog/dify-ai-rag-technology-upgrade-performance-improvement-qa-accuracy)
**Dify** is an LLM application development platform that has already seen over **100,000** applications built on Dify.AI. It integrates the concepts of Backend as a Service and LLMOps, covering the core tech stack required for building generative AI-native applications, including a built-in RAG engine. With Dify, **you can self-deploy capabilities similar to Assistants API and GPTs based on any LLMs.**
**Dify** is an LLM application development platform that has helped built over **100,000** applications. It integrates BaaS and LLMOps, covering the essential tech stack for building generative AI-native applications, including a built-in RAG engine. Dify allows you to **deploy your own version of Assistants API and GPTs, based on any LLMs.**
![](./images/demo.png)
## Use Cloud Services
Using [Dify.AI Cloud](https://dify.ai) provides all the capabilities of the open-source version, and includes a complimentary 200 GPT trial credits.
[Dify.AI Cloud](https://dify.ai) provides all the capabilities of the open-source version, and includes 200 free requests to OpenAI GPT-3.5.
## Why Dify
Dify features model neutrality and is a complete, engineered tech stack compared to hardcoded development libraries like LangChain. Unlike OpenAI's Assistants API, Dify allows for full local deployment of services.
Dify is model-agnostic and boasts a comprehensive tech stack compared to hardcoded development libraries like LangChain. Unlike OpenAI's Assistants API, Dify allows for full local deployment of services.
| Feature | Dify.AI | Assistants API | LangChain |
|---------|---------|----------------|-----------|
@@ -59,6 +59,10 @@ Dify features model neutrality and is a complete, engineered tech stack compared
## Before You Start
**Star us, and you'll get instant notifications for all new releases on GitHub!**
![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f)
- [Website](https://dify.ai)
- [Docs](https://docs.dify.ai)
- [Deployment Docs](https://docs.dify.ai/getting-started/install-self-hosted)
@@ -104,6 +108,7 @@ If you need to customize the configuration, please refer to the comments in our
We welcome you to contribute to Dify to help make Dify better in various ways, submitting code, issues, new ideas, or sharing the interesting and useful AI applications you have created based on Dify. At the same time, we also welcome you to share Dify at different events, conferences, and social media.
- [Roadmap and Feedback](https://feedback.dify.ai/). Best for: sharing feedback and checking out our feature roadmap.
- [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs and errors you encounter using Dify.AI, see the [Contribution Guide](CONTRIBUTING.md).
- [Email Support](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify). Best for: questions you have about using Dify.AI.
- [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.

View File

@@ -4,7 +4,8 @@
<a href="./README_CN.md">简体中文</a> |
<a href="./README_JA.md">日本語</a> |
<a href="./README_ES.md">Español</a> |
<a href="./README_KL.md">Klingon</a>
<a href="./README_KL.md">Klingon</a> |
<a href="./README_FR.md">Français</a>
</p>
<p align="center">
@@ -24,6 +25,10 @@ Dify 是一个 LLM 应用开发平台,已经有超过 10 万个应用基于 Di
![](./images/demo.png)
## 使用云端服务
使用 [Dify.AI Cloud](https://dify.ai) 提供开源版本的所有功能,并包含 200 次 GPT 试用额度。
## 为什么选择 Dify
Dify 具有模型中立性,相较 LangChain 等硬编码开发库 Dify 是一个完整的、工程化的技术栈,而相较于 OpenAI 的 Assistants API 你可以完全将服务部署在本地。
@@ -54,6 +59,10 @@ Dify 具有模型中立性,相较 LangChain 等硬编码开发库 Dify 是一
## 在开始之前
**关注我们,您将立即收到 GitHub 上所有新发布版本的通知!**
![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f)
- [网站](https://dify.ai)
- [文档](https://docs.dify.ai)
- [部署文档](https://docs.dify.ai/getting-started/install-self-hosted)
@@ -111,4 +120,4 @@ docker compose up -d
## License
本仓库遵循 [Dify Open Source License](LICENSE) 开源协议。
本仓库遵循 [Dify Open Source License](LICENSE) 开源协议,该许可证本质上是 Apache 2.0,但有一些额外的限制

View File

@@ -3,7 +3,9 @@
<a href="./README.md">English</a> |
<a href="./README_CN.md">简体中文</a> |
<a href="./README_JA.md">日本語</a> |
<a href="./README_ES.md">Español</a>
<a href="./README_ES.md">Español</a> |
<a href="./README_KL.md">Klingon</a> |
<a href="./README_FR.md">Français</a>
</p>
<p align="center">
@@ -56,6 +58,10 @@ Dify se caracteriza por su neutralidad de modelo y es un conjunto tecnológico c
## Antes de Empezar
**¡Danos una estrella, y recibirás notificaciones instantáneas de todos los nuevos lanzamientos en GitHub!**
![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f)
- [Sitio web](https://dify.ai)
- [Documentación](https://docs.dify.ai)
- [Documentación de Implementación](https://docs.dify.ai/getting-started/install-self-hosted)
@@ -109,4 +115,4 @@ Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En
## Licencia
Este repositorio está disponible bajo la [Licencia de código abierto de Dify](LICENSE).
Este repositorio está disponible bajo la [Licencia de Código Abierto Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales.

120
README_FR.md Normal file
View File

@@ -0,0 +1,120 @@
[![](./images/describe.png)](https://dify.ai)
<p align="center">
<a href="./README.md">English</a> |
<a href="./README_CN.md">简体中文</a> |
<a href="./README_JA.md">日本語</a> |
<a href="./README_ES.md">Español</a> |
<a href="./README_KL.md">Klingon</a> |
<a href="./README_FR.md">Français</a>
</p>
<p align="center">
<a href="https://dify.ai" target="_blank">
<img alt="Static Badge" src="https://img.shields.io/badge/AI-Dify?logo=AI&logoColor=%20%23f5f5f5&label=Dify&labelColor=%20%23155EEF&color=%23EAECF0"></a>
<a href="https://discord.gg/FngNHpbcY7" target="_blank">
<img src="https://img.shields.io/discord/1082486657678311454?logo=discord"
alt="chat on Discord"></a>
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?style=social&logo=X"
alt="follow on Twitter"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>
**Dify** est une plateforme de développement d'applications LLM qui a déjà vu plus de **100,000** applications construites sur Dify.AI. Elle intègre les concepts de Backend as a Service et LLMOps, couvrant la pile technologique de base requise pour construire des applications natives d'IA générative, y compris un moteur RAG intégré. Avec Dify, **vous pouvez auto-déployer des capacités similaires aux API Assistants et GPT basées sur n'importe quels LLM.**
![](./images/demo.png)
## Utiliser les services cloud
L'utilisation de [Dify.AI Cloud](https://dify.ai) fournit toutes les capacités de la version open source, et comprend un essai gratuit de 200 crédits GPT.
## Pourquoi Dify
Dify présente une neutralité de modèle et est une pile technologique complète et conçue par rapport à des bibliothèques de développement codées en dur comme LangChain. Contrairement à l'API Assistants d'OpenAI, Dify permet un déploiement local complet des services.
| Fonctionnalité | Dify.AI | API Assistants | LangChain |
|---------------|----------|-----------------|------------|
| **Approche de programmation** | Orientée API | Orientée API | Orientée code Python |
| **Stratégie écosystème** | Open source | Fermé et commercial | Open source |
| **Moteur RAG** | Pris en charge | Pris en charge | Non pris en charge |
| **IDE d'invite** | Inclus | Inclus | Aucun |
| **LLM pris en charge** | Grande variété | Seulement GPT | Grande variété |
| **Déploiement local** | Pris en charge | Non pris en charge | Non applicable |
## Fonctionnalités
![](./images/models.png)
**1\. Support LLM**: Intégration avec la famille de modèles GPT d'OpenAI, ou les modèles de la famille open source Llama2. En fait, Dify prend en charge les modèles commerciaux grand public et les modèles open source (déployés localement ou basés sur MaaS).
**2\. IDE d'invite**: Orchestration visuelle d'applications et de services basés sur LLMs avec votre équipe.
**3\. Moteur RAG**: Comprend diverses capacités RAG basées sur l'indexation de texte intégral ou les embeddings de base de données vectorielles, permettant le chargement direct de PDF, TXT et autres formats de texte.
**4\. Agents**: Un framework d'agents basé sur l'appel de fonctions qui permet aux utilisateurs de configurer ce qu'ils voient est ce qu'ils obtiennent. Dify comprend des capacités de plug-in de base comme Google Search.
**5\. Opérations continues**: Surveillez et analysez les journaux et les performances des applications, améliorez en continu les invites, les datasets ou les modèles à l'aide de données de production.
## Avant de commencer
**Étoilez-nous, et vous recevrez des notifications instantanées pour toutes les nouvelles sorties sur GitHub !**
![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f)
- [Site web](https://dify.ai)
- [Documentation](https://docs.dify.ai)
- [Documentation de déploiement](https://docs.dify.ai/getting-started/install-self-hosted)
- [FAQ](https://docs.dify.ai/getting-started/faq)
## Installer la version Communauté
### Configuration système
Avant d'installer Dify, assurez-vous que votre machine répond aux exigences minimales suivantes:
- CPU >= 2 cœurs
- RAM >= 4 Go
### Démarrage rapide
La façon la plus simple de démarrer le serveur Dify est d'exécuter notre fichier [docker-compose.yml](docker/docker-compose.yaml). Avant d'exécuter la commande d'installation, assurez-vous que [Docker](https://docs.docker.com/get-docker/) et [Docker Compose](https://docs.docker.com/compose/install/) sont installés sur votre machine:
```bash
cd docker
docker compose up -d
```
Après l'exécution, vous pouvez accéder au tableau de bord Dify dans votre navigateur à l'adresse [http://localhost/install](http://localhost/install) et démarrer le processus d'installation initiale.
### Chart Helm
Un grand merci à @BorisPolonsky pour nous avoir fourni une version [Helm Chart](https://helm.sh/) qui permet le déploiement de Dify sur Kubernetes.
Vous pouvez accéder à https://github.com/BorisPolonsky/dify-helm pour des informations de déploiement.
### Configuration
Si vous avez besoin de personnaliser la configuration, veuillez vous référer aux commentaires de notre fichier [docker-compose.yml](docker/docker-compose.yaml) et définir manuellement la configuration de l'environnement. Après avoir apporté les modifications, veuillez exécuter à nouveau `docker-compose up -d`. Vous trouverez la liste complète des variables d'environnement dans notre [documentation](https://docs.dify.ai/getting-started/install-self-hosted/environments).
## Historique d'étoiles
[![Diagramme de l'historique des étoiles](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)
## Communauté & Support
Nous vous invitons à contribuer à Dify pour aider à améliorer Dify de diverses manières, en soumettant du code, des problèmes, de nouvelles idées ou en partageant les applications d'IA intéressantes et utiles que vous avez créées sur la base de Dify. En même temps, nous vous invitons également à partager Dify lors de différents événements, conférences et réseaux sociaux.
- [Problèmes GitHub](https://github.com/langgenius/dify/issues). Idéal pour : les bogues et les erreurs que vous rencontrez en utilisant Dify.AI, voir le [Guide de contribution](CONTRIBUTING.md).
- [Support par courriel](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify). Idéal pour : les questions que vous avez au sujet de l'utilisation de Dify.AI.
- [Discord](https://discord.gg/FngNHpbcY7). Idéal pour : partager vos applications et discuter avec la communauté.
- [Twitter](https://twitter.com/dify_ai). Idéal pour : partager vos applications et discuter avec la communauté.
- [Licence commerciale](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry). Idéal pour : les demandes commerciales de licence de Dify.AI pour un usage commercial.
## Divulgation de la sécurité
Pour protéger votre vie privée, veuillez éviter de publier des problèmes de sécurité sur GitHub. Envoyez plutôt vos questions à security@dify.ai et nous vous fournirons une réponse plus détaillée.
## Licence
Ce référentiel est disponible sous la [Licence open source Dify](LICENSE), qui est essentiellement Apache 2.0 avec quelques restrictions supplémentaires.

View File

@@ -4,7 +4,8 @@
<a href="./README_CN.md">简体中文</a> |
<a href="./README_JA.md">日本語</a> |
<a href="./README_ES.md">Español</a> |
<a href="./README_KL.md">Klingon</a>
<a href="./README_KL.md">Klingon</a> |
<a href="./README_FR.md">Français</a>
</p>
<p align="center">
@@ -24,6 +25,8 @@
Please note that translating complex technical terms can sometimes result in slight variations in meaning due to differences in language nuances.
![](./images/demo.png)
## クラウドサービスの利用
[Dify.AI Cloud](https://dify.ai) を使用すると、オープンソース版の全機能を利用でき、さらに200GPTのトライアルクレジットが無料で提供されます。
@@ -41,9 +44,26 @@ Difyはモデルニュートラルであり、LangChainのようなハードコ
| **サポートされるLLMs** | 豊富な種類 | GPTのみ | 豊富な種類 |
| **ローカルデプロイメント** | サポート済み | 非サポート | 該当なし |
## 機能
![](./images/models.png)
**1\. LLMサポート**: OpenAIのGPTファミリーモデルやLlama2ファミリーのオープンソースモデルとの統合。 実際、Difyは主要な商用モデルとオープンソースモデル(ローカルでデプロイまたはMaaSベース)をサポートしています。
**2\. プロンプトIDE**: チームとのLLMベースのアプリケーションとサービスの視覚的なオーケストレーション。
**3\. RAGエンジン**: フルテキストインデックスまたはベクトルデータベース埋め込みに基づくさまざまなRAG機能を含み、PDF、TXT、その他のテキストフォーマットの直接アップロードを可能にします。
**4\. エージェント**: ユーザーが sees what they get を設定できる関数呼び出しベースのエージェントフレームワーク。 Difyには、Google検索などの基本的なプラグイン機能が含まれています。
**5\. 継続的運用**: アプリケーションログとパフォーマンスを監視および分析し、運用データを使用してプロンプト、データセット、またはモデルを継続的に改善します。
## 開始する前に
**私たちをスターして、GitHub上でのすべての新しいリリースに対する即時通知を受け取ります**
![私たちをスターして](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f)
- [Website](https://dify.ai)
- [Docs](https://docs.dify.ai)
- [Deployment Docs](https://docs.dify.ai/getting-started/install-self-hosted)
@@ -100,4 +120,4 @@ Difyに貢献していただき、コードの提出、問題の報告、新し
## ライセンス
このリポジトリは、[Dify Open Source License](LICENSE) のもとで利用できます。
このリポジトリは、基本的にApache 2.0にいくつかの追加制限を加えた[Difyオープンソースライセンス](LICENSE)の下で利用できます。

View File

@@ -4,7 +4,8 @@
<a href="./README_CN.md">简体中文</a> |
<a href="./README_JA.md">日本語</a> |
<a href="./README_ES.md">Español</a> |
<a href="./README_KL.md">Klingon</a>
<a href="./README_KL.md">Klingon</a> |
<a href="./README_FR.md">Français</a>
</p>
<p align="center">
@@ -57,6 +58,10 @@ Dify Daq rIn neutrality 'ej Hoch, LangChain tInHar HubwI'. maH Daqbe'law' Qawqar
## Do'wI' qabmey lo'taH
**maHvaD jatlhchugh, GitHub Daq Hoch chu' ghompu'vam tIqel yInob!**
![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f)
- [Website](https://dify.ai)
- [Docs](https://docs.dify.ai)
- [lo'taHmoH Docs](https://docs.dify.ai/getting-started/install-self-hosted)

View File

@@ -106,8 +106,6 @@ HOSTED_OPENAI_API_BASE=
HOSTED_OPENAI_API_ORGANIZATION=
HOSTED_OPENAI_QUOTA_LIMIT=200
HOSTED_OPENAI_PAID_ENABLED=false
HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
HOSTED_OPENAI_PAID_INCREASE_QUOTA=1
HOSTED_AZURE_OPENAI_ENABLED=false
HOSTED_AZURE_OPENAI_API_KEY=
@@ -119,16 +117,6 @@ HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
HOSTED_ANTHROPIC_PAID_ENABLED=false
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
# Stripe configuration
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
# Billing configuration
BILLING_API_URL=http://127.0.0.1:8000/v1
BILLING_API_SECRET_KEY=
STRIPE_WEBHOOK_BILLING_SECRET=
ETL_TYPE=dify
UNSTRUCTURED_API_URL=

View File

@@ -4,6 +4,21 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Celery",
"type": "python",
"request": "launch",
"module": "celery",
"justMyCode": true,
"args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
"envFile": "${workspaceFolder}/.env",
"env": {
"FLASK_APP": "app.py",
"FLASK_DEBUG": "1",
"GEVENT_SUPPORT": "True"
},
"console": "integratedTerminal"
},
{
"name": "Python: Flask",
"type": "python",

View File

@@ -34,9 +34,6 @@ RUN apt-get update \
COPY --from=base /pkg /usr/local
COPY . /app/api/
RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')"
ENV TRANSFORMERS_OFFLINE true
COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

View File

@@ -6,9 +6,12 @@ from werkzeug.exceptions import Unauthorized
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
monkey.patch_all()
if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()
# if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()
import langchain
langchain.verbose = True
import time
import logging
@@ -18,9 +21,8 @@ import threading
from flask import Flask, request, Response
from flask_cors import CORS
from core.model_providers.providers import hosted
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension
ext_database, ext_storage, ext_mail, ext_code_based_extension, ext_hosting_provider
from extensions.ext_database import db
from extensions.ext_login import login_manager
@@ -79,8 +81,6 @@ def create_app(test_config=None) -> Flask:
register_blueprints(app)
register_commands(app)
hosted.init_app(app)
return app
@@ -95,8 +95,8 @@ def initialize_extensions(app):
ext_celery.init_app(app)
ext_login.init_app(app)
ext_mail.init_app(app)
ext_hosting_provider.init_app(app)
ext_sentry.init_app(app)
ext_stripe.init_app(app)
# Flask-Login configuration
@@ -106,13 +106,18 @@ def load_user_from_request(request_from_flask_login):
if request.blueprint == 'console':
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get('Authorization', '')
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if not auth_header:
auth_token = request.args.get('_token')
if not auth_token:
raise Unauthorized('Invalid Authorization token.')
else:
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id')

View File

@@ -12,23 +12,19 @@ import qdrant_client
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
from tqdm import tqdm
from flask import current_app, Flask
from langchain.embeddings import OpenAIEmbeddings
from werkzeug.exceptions import NotFound
from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.hosted import hosted_model_providers
from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
from models.model import Account, AppModelConfig, App, MessageAnnotation, Message
import secrets
import base64
@@ -327,6 +323,8 @@ def create_qdrant_indexes():
except NotFound:
break
model_manager = ModelManager()
page += 1
for dataset in datasets:
if dataset.index_struct_dict:
@@ -334,19 +332,23 @@ def create_qdrant_indexes():
try:
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except Exception:
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
embedding_model = model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
@@ -752,6 +754,30 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
pbar.update(len(data_batch))
@click.command('add-annotation-question-field-value', help='add annotation question value')
def add_annotation_question_field_value():
click.echo(click.style('Start add annotation question value.', fg='green'))
message_annotations = db.session.query(MessageAnnotation).all()
message_annotation_deal_count = 0
if message_annotations:
for message_annotation in message_annotations:
try:
if message_annotation.message_id and not message_annotation.question:
message = db.session.query(Message).filter(
Message.id == message_annotation.message_id
).first()
message_annotation.question = message.query
db.session.add(message_annotation)
db.session.commit()
message_annotation_deal_count += 1
except Exception as e:
click.echo(
click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.echo(
click.style(f'Congratulations! add annotation question value successful. Deal count {message_annotation_deal_count}', fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
@@ -766,3 +792,4 @@ def register_commands(app):
app.cli.add_command(normalization_collections)
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
app.cli.add_command(add_qdrant_full_text_index)
app.cli.add_command(add_annotation_question_field_value)

View File

@@ -1,11 +1,8 @@
# -*- coding:utf-8 -*-
import os
from datetime import timedelta
import dotenv
from extensions.ext_database import db
from extensions.ext_redis import redis_client
dotenv.load_dotenv()
@@ -44,15 +41,11 @@ DEFAULTS = {
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_ENABLED': 'False',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
'HOSTED_ANTHROPIC_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
'HOSTED_MODERATION_ENABLED': 'False',
'HOSTED_MODERATION_PROVIDERS': '',
'CLEAN_DAY_SETTING': 30,
@@ -61,7 +54,10 @@ DEFAULTS = {
'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
'OUTPUT_MODERATION_BUFFER_SIZE': 300,
'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64',
'INVITE_EXPIRY_HOURS': 72
'INVITE_EXPIRY_HOURS': 72,
'BILLING_ENABLED': 'False',
'CAN_REPLACE_LOGO': 'False',
'ETL_TYPE': 'dify',
}
@@ -91,7 +87,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.3.33"
self.CURRENT_VERSION = "0.4.0"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -268,8 +264,6 @@ class Config:
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
@@ -281,14 +275,15 @@ class Config:
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
self.ETL_TYPE = get_env('ETL_TYPE')
self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
class CloudEditionConfig(Config):
@@ -302,6 +297,3 @@ class CloudEditionConfig(Config):
self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

View File

@@ -6,10 +6,10 @@ bp = Blueprint('console', __name__, url_prefix='/console/api')
api = ExternalApi(bp)
# Import other controllers
from . import extension, setup, version, apikey, admin
from . import extension, setup, version, apikey, admin, feature
# Import app controllers
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation
# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate
@@ -18,7 +18,7 @@ from .auth import login, oauth, data_source_oauth, activate
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
# Import workspace controllers
from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
from .workspace import workspace, members, model_providers, account, tool_providers, models
# Import explore controllers
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
@@ -26,7 +26,4 @@ from .explore import installed_app, recommended_app, completion, conversation, m
# Import universal chat controllers
from .universal_chat import chat, conversation, message, parameter, audio
# Import webhook controllers
from .webhook import stripe
from .billing import billing

View File

@@ -0,0 +1,290 @@
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, marshal
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import NoFileUploadedError
from controllers.console.datasets.error import TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_list_fields, annotation_hit_history_list_fields, annotation_fields, \
annotation_hit_history_fields
from libs.login import login_required
from services.annotation_service import AppAnnotationService
from flask import request
class AnnotationReplyActionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def post(self, app_id, action):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument('score_threshold', required=True, type=float, location='json')
parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
parser.add_argument('embedding_model_name', required=True, type=str, location='json')
args = parser.parse_args()
if action == 'enable':
result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == 'disable':
result = AppAnnotationService.disable_app_annotation(app_id)
else:
raise ValueError('Unsupported annotation reply action')
return result, 200
class AppAnnotationSettingDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
return result, 200
class AppAnnotationSettingUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_id, annotation_setting_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser()
parser.add_argument('score_threshold', required=True, type=float, location='json')
args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
return result, 200
class AnnotationReplyActionStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id, action):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
job_id = str(job_id)
app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job is not exist.")
job_status = cache_result.decode()
error_msg = ''
if job_status == 'error':
app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode()
return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200
class AnnotationListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
keyword = request.args.get('keyword', default=None, type=str)
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
response = {
'data': marshal(annotation_list, annotation_fields),
'has_more': len(annotation_list) == limit,
'limit': limit,
'total': total,
'page': page
}
return response, 200
class AnnotationExportApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {
'data': marshal(annotation_list, annotation_fields)
}
return response, 200
class AnnotationCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation
class AnnotationUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation
@setup_required
@login_required
@account_initialization_required
def delete(self, app_id, annotation_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
return {'result': 'success'}, 200
class AnnotationBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
# get file from request
file = request.files['file']
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed")
return AppAnnotationService.batch_import_app_annotations(app_id, file)
class AnnotationBatchImportStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
job_id = str(job_id)
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
raise ValueError("The job is not exist.")
job_status = cache_result.decode()
error_msg = ''
if job_status == 'error':
indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
error_msg = redis_client.get(indexing_error_msg_key).decode()
return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200
class AnnotationHitHistoryListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id, annotation_id):
# The role of the current user in the table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
app_id = str(app_id)
annotation_id = str(annotation_id)
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
page, limit)
response = {
'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
'has_more': len(annotation_hit_history_list) == limit,
'limit': limit,
'total': total,
'page': page
}
return response
api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
api.add_resource(AnnotationReplyActionStatusApi,
'/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')

View File

@@ -4,6 +4,10 @@ import logging
from datetime import datetime
from flask_login import current_user
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with, abort, inputs
from werkzeug.exceptions import Forbidden
@@ -13,9 +17,7 @@ from controllers.console import api
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
from events.app_event import app_was_created, app_was_deleted
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
app_detail_fields_with_site
@@ -73,39 +75,41 @@ class AppListApi(Resource):
raise Forbidden()
try:
default_model = ModelFactory.get_text_generation_model(
tenant_id=current_user.current_tenant_id
provider_manager = ProviderManager()
default_model_entity = provider_manager.get_default_model(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
except (ProviderTokenNotInitError, LLMBadRequestError):
default_model = None
default_model_entity = None
except Exception as e:
logging.exception(e)
default_model = None
default_model_entity = None
if args['model_config'] is not None:
# validate config
model_config_dict = args['model_config']
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(
current_user.current_tenant_id,
model_config_dict["model"]["provider"]
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
if not model_provider:
if not default_model:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
model_config_dict["model"]["name"] = default_model.name
if not model_instance:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = model_instance.provider
model_config_dict["model"]["name"] = model_instance.model
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=model_config_dict,
mode=args['mode']
app_mode=args['mode']
)
app = App(
@@ -129,21 +133,27 @@ class AppListApi(Resource):
app_model_config = AppModelConfig(**model_config_template['model_config'])
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(
current_user.current_tenant_id,
app_model_config.model_dict["provider"]
)
model_manager = ModelManager()
if not model_provider:
if not default_model:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_dict = app_model_config.model_dict
model_dict['provider'] = default_model.model_provider.provider_name
model_dict['name'] = default_model.name
app_model_config.model = json.dumps(model_dict)
try:
model_instance = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
if not model_instance:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_dict = app_model_config.model_dict
model_dict['provider'] = model_instance.provider
model_dict['name'] = model_instance.model
app_model_config.model = json.dumps(model_dict)
app.name = args['name']
app.mode = args['mode']

View File

@@ -2,6 +2,8 @@
import logging
from flask import request
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from werkzeug.exceptions import InternalServerError
@@ -14,8 +16,7 @@ from controllers.console.app.error import AppUnavailableError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from flask_restful import Resource
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
@@ -56,8 +57,7 @@ class ChatMessageAudioApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e

View File

@@ -5,6 +5,10 @@ from typing import Generator, Union
import flask_login
from flask import Response, stream_with_context
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from werkzeug.exceptions import InternalServerError, NotFound
@@ -16,9 +20,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.conversation_message_task import PubHandler
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from flask_restful import Resource, reqparse
@@ -56,7 +58,7 @@ class CompletionMessageApi(Resource):
app_model=app_model,
user=account,
args=args,
from_source='console',
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming,
is_model_config_override=True
)
@@ -75,8 +77,7 @@ class CompletionMessageApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -97,7 +98,7 @@ class CompletionMessageStopApi(Resource):
account = flask_login.current_user
PubHandler.stop(account, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {'result': 'success'}, 200
@@ -132,7 +133,7 @@ class ChatMessageApi(Resource):
app_model=app_model,
user=account,
args=args,
from_source='console',
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming,
is_model_config_override=True
)
@@ -151,8 +152,7 @@ class ChatMessageApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -182,9 +182,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
@@ -207,7 +206,7 @@ class ChatMessageStopApi(Resource):
account = flask_login.current_user
PubHandler.stop(account, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {'result': 'success'}, 200

View File

@@ -72,4 +72,16 @@ class UnsupportedAudioTypeError(BaseHTTPException):
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text."
code = 400
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = 'no_file_uploaded'
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files'
description = "Only one file is allowed."
code = 400

View File

@@ -1,4 +1,6 @@
from flask_login import current_user
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from flask_restful import Resource, reqparse
@@ -8,8 +10,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.generator.llm_generator import LLMGenerator
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
class RuleGenerateApi(Resource):
@@ -36,8 +37,7 @@ class RuleGenerateApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
return rules

View File

@@ -6,22 +6,24 @@ from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, fields
from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from fields.conversation_fields import message_detail_fields
from fields.conversation_fields import message_detail_fields, annotation_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@@ -151,44 +153,24 @@ class MessageAnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
app_id = str(app_id)
# get app info
app = _get_app(app_id)
parser = reqparse.RequestParser()
parser.add_argument('message_id', required=True, type=uuid_value, location='json')
parser.add_argument('content', type=str, location='json')
parser.add_argument('message_id', required=False, type=uuid_value, location='json')
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
parser.add_argument('annotation_reply', required=False, type=dict, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
message_id = str(args['message_id'])
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app.id
).first()
if not message:
raise NotFound("Message Not Exists.")
annotation = message.annotation
if annotation:
annotation.content = args['content']
else:
annotation = MessageAnnotation(
app_id=app.id,
conversation_id=message.conversation_id,
message_id=message.id,
content=args['content'],
account_id=current_user.id
)
db.session.add(annotation)
db.session.commit()
return {'result': 'success'}
return annotation
class MessageAnnotationCountApi(Resource):
@@ -227,7 +209,13 @@ class MessageMoreLikeThisApi(Resource):
app_model = _get_app(app_id, 'completion')
try:
response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
response = CompletionService.generate_more_like_this(
app_model=app_model,
user=current_user,
message_id=message_id,
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming
)
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@@ -239,8 +227,7 @@ class MessageMoreLikeThisApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -268,8 +255,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(
api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -309,8 +295,7 @@ class MessageSuggestedQuestionApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")

View File

@@ -24,29 +24,29 @@ class ModelConfigResource(Resource):
"""Modify app model config"""
app_id = str(app_id)
app_model = _get_app(app_id)
app = _get_app(app_id)
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=request.json,
mode=app_model.mode
app_mode=app.mode
)
new_app_model_config = AppModelConfig(
app_id=app_model.id,
app_id=app.id,
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
db.session.add(new_app_model_config)
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
app.app_model_config_id = new_app_model_config.id
db.session.commit()
app_model_config_was_updated.send(
app_model,
app,
app_model_config=new_app_model_config
)

View File

@@ -1,9 +1,5 @@
import stripe
import os
from flask_restful import Resource, reqparse
from flask_login import current_user
from flask import current_app, request
from controllers.console import api
from controllers.console.setup import setup_required
@@ -13,20 +9,6 @@ from libs.login import login_required
from services.billing_service import BillingService
class BillingInfo(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
edition = current_app.config['EDITION']
if edition != 'CLOUD':
return {"enabled": False}
return BillingService.get_info(current_user.current_tenant_id)
class Subscription(Resource):
@setup_required
@@ -40,7 +22,12 @@ class Subscription(Resource):
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
args = parser.parse_args()
return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id)
BillingService.is_tenant_owner(current_user)
return BillingService.get_subscription(args['plan'],
args['interval'],
current_user.email,
current_user.current_tenant_id)
class Invoices(Resource):
@@ -50,36 +37,9 @@ class Invoices(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
BillingService.is_tenant_owner(current_user)
return BillingService.get_invoices(current_user.email)
class StripeBillingWebhook(Resource):
@setup_required
@only_edition_cloud
def post(self):
payload = request.data
sig_header = request.headers.get('STRIPE_SIGNATURE')
webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET')
try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError as e:
# Invalid payload
return 'Invalid payload', 400
except stripe.error.SignatureVerificationError as e:
# Invalid signature
return 'Invalid signature', 400
BillingService.process_event(event)
return 'success', 200
api.add_resource(BillingInfo, '/billing/info')
api.add_resource(Subscription, '/billing/subscription')
api.add_resource(Invoices, '/billing/invoices')
api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe')

View File

@@ -4,6 +4,8 @@ from flask import request, current_app
from flask_login import current_user
from controllers.console.apikey import api_key_list, api_key_fields
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
@@ -14,8 +16,7 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.models.entity.model_params import ModelType
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
@@ -23,7 +24,6 @@ from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
from models.model import UploadFile, ApiToken
from services.dataset_service import DatasetService, DocumentService
from services.provider_service import ProviderService
def _validate_name(name):
@@ -55,16 +55,20 @@ class DatasetListApi(Resource):
current_user.current_tenant_id, current_user)
# check embedding setting
provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
# if len(valid_model_list) == 0:
# raise ProviderNotInitializeError(
# f"No Embedding Model available. Please configure a valid provider "
# f"in the Settings -> Model Provider.")
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)
embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item['indexing_technique'] == 'high_quality':
@@ -75,6 +79,7 @@ class DatasetListApi(Resource):
item['embedding_available'] = False
else:
item['embedding_available'] = True
response = {
'data': data,
'has_more': len(datasets) == limit,
@@ -130,13 +135,20 @@ class DatasetApi(Resource):
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
# check embedding setting
provider_service = ProviderService()
# get valid model list
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)
embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data['indexing_technique'] == 'high_quality':
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:

View File

@@ -2,8 +2,12 @@
from datetime import datetime
from typing import List
from flask import request, current_app
from flask import request
from flask_login import current_user
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from libs.login import login_required
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import desc, asc
@@ -18,9 +22,8 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.indexing_runner import IndexingRunner
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from fields.document_fields import document_with_segments_fields, document_fields, \
dataset_and_document_fields, document_status_fields
@@ -272,10 +275,12 @@ class DatasetInitApi(Resource):
args = parser.parse_args()
if args['indexing_technique'] == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
model_manager = ModelManager()
model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
except LLMBadRequestError:
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
@@ -410,7 +415,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if dataset.data_source_type == 'upload_file':
file_details = db.session.query(UploadFile).filter(
UploadFile.tenant_id == current_user.current_tenant_id,
UploadFile.id in info_list
UploadFile.id.in_(info_list)
).all()
if file_details is None:

View File

@@ -12,8 +12,9 @@ from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from libs.login import login_required
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -133,10 +134,12 @@ class DatasetDocumentSegmentApi(Resource):
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
@@ -219,10 +222,12 @@ class DatasetDocumentSegmentAddApi(Resource):
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
@@ -269,10 +274,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(

View File

@@ -69,5 +69,20 @@ class FilePreviewApi(Resource):
return {'content': text}
class FileSupportTypeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
etl_type = current_app.config['ETL_TYPE']
if etl_type == 'Unstructured':
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml']
else:
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
return {'allowed_extensions': allowed_extensions}
api.add_resource(FileApi, '/files/upload')
api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
api.add_resource(FileSupportTypeApi, '/files/support-type')

View File

@@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError
from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService

View File

@@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.explore.wraps import InstalledAppResource
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@@ -53,8 +53,7 @@ class ChatAudioApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e

View File

@@ -15,9 +15,10 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.conversation_message_task import PubHandler
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs.helper import uuid_value
from services.completion_service import CompletionService
@@ -50,7 +51,7 @@ class CompletionApi(InstalledAppResource):
app_model=app_model,
user=current_user,
args=args,
from_source='console',
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
)
@@ -68,8 +69,7 @@ class CompletionApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -84,7 +84,7 @@ class CompletionStopApi(InstalledAppResource):
if app_model.mode != 'completion':
raise NotCompletionAppError()
PubHandler.stop(current_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {'result': 'success'}, 200
@@ -115,7 +115,7 @@ class ChatApi(InstalledAppResource):
app_model=app_model,
user=current_user,
args=args,
from_source='console',
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
)
@@ -133,8 +133,7 @@ class ChatApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -149,7 +148,7 @@ class ChatStopApi(InstalledAppResource):
if app_model.mode != 'chat':
raise NotChatAppError()
PubHandler.stop(current_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {'result': 'success'}, 200
@@ -175,8 +174,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

View File

@@ -73,7 +73,7 @@ class ConversationRenameApi(InstalledAppResource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
try:

View File

@@ -5,7 +5,7 @@ from typing import Generator, Union
from flask import stream_with_context, Response
from flask_login import current_user
from flask_restful import reqparse, fields, marshal_with
from flask_restful import reqparse, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound, InternalServerError
@@ -13,12 +13,14 @@ import services
from controllers.console import api
from controllers.console.app.error import AppMoreLikeThisDisabledError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs.helper import uuid_value, TimestampField
from libs.helper import uuid_value
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@@ -83,7 +85,13 @@ class MessageMoreLikeThisApi(InstalledAppResource):
streaming = args['response_mode'] == 'streaming'
try:
response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
response = CompletionService.generate_more_like_this(
app_model=app_model,
user=current_user,
message_id=message_id,
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
)
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@@ -95,8 +103,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -123,8 +130,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -162,8 +168,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")

View File

@@ -30,6 +30,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
@@ -49,6 +50,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,

View File

@@ -0,0 +1,14 @@
from flask_restful import Resource
from flask_login import current_user
from . import api
from services.feature_service import FeatureService
class FeatureApi(Resource):
def get(self):
return FeatureService.get_features(current_user.current_tenant_id).dict()
api.add_resource(FeatureApi, '/features')

View File

@@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@@ -53,8 +53,7 @@ class UniversalChatAudioApi(UniversalChatResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e

View File

@@ -12,9 +12,10 @@ from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.conversation_message_task import PubHandler
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from services.completion_service import CompletionService
@@ -68,7 +69,7 @@ class UniversalChatApi(UniversalChatResource):
app_model=app_model,
user=current_user,
args=args,
from_source='console',
invoke_from=InvokeFrom.EXPLORE,
streaming=True,
is_model_config_override=True,
)
@@ -87,8 +88,7 @@ class UniversalChatApi(UniversalChatResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -99,7 +99,7 @@ class UniversalChatApi(UniversalChatResource):
class UniversalChatStopApi(UniversalChatResource):
def post(self, universal_app, task_id):
PubHandler.stop(current_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {'result': 'success'}, 200
@@ -125,8 +125,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

View File

@@ -66,7 +66,7 @@ class UniversalChatConversationRenameApi(UniversalChatResource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
try:

View File

@@ -12,8 +12,8 @@ from controllers.console.app.error import ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value, TimestampField
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@@ -132,8 +132,7 @@ class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")

View File

@@ -17,6 +17,7 @@ class UniversalChatParameterApi(UniversalChatResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw
}
@marshal_with(parameters_fields)
@@ -32,6 +33,7 @@ class UniversalChatParameterApi(UniversalChatResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
}

View File

@@ -1,61 +0,0 @@
import logging
import stripe
from flask import request, current_app
from flask_restful import Resource
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import only_edition_cloud
from services.provider_checkout_service import ProviderCheckoutService
class StripeWebhookApi(Resource):
@setup_required
@only_edition_cloud
def post(self):
payload = request.data
sig_header = request.headers.get('STRIPE_SIGNATURE')
webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')
try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError as e:
# Invalid payload
return 'Invalid payload', 400
except stripe.error.SignatureVerificationError as e:
# Invalid signature
return 'Invalid signature', 400
# Handle the checkout.session.completed event
if event['type'] == 'checkout.session.completed':
logging.debug(event['data']['object']['id'])
logging.debug(event['data']['object']['amount_subtotal'])
logging.debug(event['data']['object']['currency'])
logging.debug(event['data']['object']['payment_intent'])
logging.debug(event['data']['object']['payment_status'])
logging.debug(event['data']['object']['metadata'])
session = stripe.checkout.Session.retrieve(
event['data']['object']['id'],
expand=['line_items'],
)
logging.debug(session.line_items['data'][0]['quantity'])
# Fulfill the purchase...
provider_checkout_service = ProviderCheckoutService()
try:
provider_checkout_service.fulfill_provider_order(event, session.line_items)
except Exception as e:
logging.debug(str(e))
return 'success', 200
return 'success', 200
api.add_resource(StripeWebhookApi, '/webhook/stripe')

View File

@@ -1,16 +1,19 @@
import io
from flask import send_file
from flask_login import current_user
from libs.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import CredentialsValidateFailedError
from services.provider_checkout_service import ProviderCheckoutService
from services.provider_service import ProviderService
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
class ModelProviderListApi(Resource):
@@ -22,13 +25,36 @@ class ModelProviderListApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model_type', type=str, required=False, nullable=True, location='args')
parser.add_argument('model_type', type=str, required=False, nullable=True,
choices=[mt.value for mt in ModelType], location='args')
args = parser.parse_args()
provider_service = ProviderService()
provider_list = provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get('model_type'))
model_provider_service = ModelProviderService()
provider_list = model_provider_service.get_provider_list(
tenant_id=tenant_id,
model_type=args.get('model_type')
)
return provider_list
return jsonable_encoder({"data": provider_list})
class ModelProviderCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credentials(
tenant_id=tenant_id,
provider=provider
)
return {
"credentials": credentials
}
class ModelProviderValidateApi(Resource):
@@ -36,21 +62,24 @@ class ModelProviderValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
def post(self, provider: str):
parser = reqparse.RequestParser()
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
result = True
error = None
try:
provider_service.custom_provider_config_validate(
provider_name=provider_name,
config=args['config']
model_provider_service.provider_credentials_validate(
tenant_id=tenant_id,
provider=provider,
credentials=args['credentials']
)
except CredentialsValidateFailedError as ex:
result = False
@@ -64,26 +93,26 @@ class ModelProviderValidateApi(Resource):
return response
class ModelProviderUpdateApi(Resource):
class ModelProviderApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
def post(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
model_provider_service = ModelProviderService()
try:
provider_service.save_custom_provider_config(
model_provider_service.save_provider_credentials(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
config=args['config']
provider=provider,
credentials=args['credentials']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
@@ -93,109 +122,36 @@ class ModelProviderUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, provider_name: str):
def delete(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
provider_service = ProviderService()
provider_service.delete_custom_provider(
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credentials(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name
provider=provider
)
return {'result': 'success'}, 204
class ModelProviderModelValidateApi(Resource):
class ModelProviderIconApi(Resource):
"""
Get model provider icon
"""
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
result = True
error = None
try:
provider_service.custom_provider_model_config_validate(
provider_name=provider_name,
model_name=args['model_name'],
model_type=args['model_type'],
config=args['config']
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
class ModelProviderModelUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
try:
provider_service.add_or_save_custom_provider_model_config(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
model_name=args['model_name'],
model_type=args['model_type'],
config=args['config']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, provider_name: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
args = parser.parse_args()
provider_service = ProviderService()
provider_service.delete_custom_provider_model(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
model_name=args['model_name'],
model_type=args['model_type']
def get(self, provider: str, icon_type: str, lang: str):
model_provider_service = ModelProviderService()
icon, mimetype = model_provider_service.get_model_provider_icon(
provider=provider,
icon_type=icon_type,
lang=lang
)
return {'result': 'success'}, 204
return send_file(io.BytesIO(icon), mimetype=mimetype)
class PreferredProviderTypeUpdateApi(Resource):
@@ -203,88 +159,50 @@ class PreferredProviderTypeUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
def post(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
choices=['system', 'custom'], location='json')
args = parser.parse_args()
provider_service = ProviderService()
provider_service.switch_preferred_provider(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
model_provider_service = ModelProviderService()
model_provider_service.switch_preferred_provider(
tenant_id=tenant_id,
provider=provider,
preferred_provider_type=args['preferred_provider_type']
)
return {'result': 'success'}
class ModelProviderModelParameterRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
provider_service = ProviderService()
try:
parameter_rules = provider_service.get_model_parameter_rules(
tenant_id=current_user.current_tenant_id,
model_provider_name=provider_name,
model_name=args['model_name'],
model_type='text-generation'
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"Current Text Generation Model is invalid. Please switch to the available model.")
rules = {
k: {
'enabled': v.enabled,
'min': v.min,
'max': v.max,
'default': v.default,
'precision': v.precision
}
for k, v in vars(parameter_rules).items()
}
return rules
class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
provider_service = ProviderCheckoutService()
provider_checkout = provider_service.create_checkout(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
account=current_user
)
def get(self, provider: str):
if provider != 'anthropic':
raise ValueError(f'provider name {provider} is invalid')
return {
'url': provider_checkout.get_checkout_url()
}
data = BillingService.get_model_provider_payment_link(provider_name=provider,
tenant_id=current_user.current_tenant_id,
account_id=current_user.id)
return data
class ModelProviderFreeQuotaSubmitApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
provider_service = ProviderService()
result = provider_service.free_quota_submit(
def post(self, provider: str):
model_provider_service = ModelProviderService()
result = model_provider_service.free_quota_submit(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name
provider=provider
)
return result
@@ -294,15 +212,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
def get(self, provider: str):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=False, nullable=True, location='args')
args = parser.parse_args()
provider_service = ProviderService()
result = provider_service.free_quota_qualification_verify(
model_provider_service = ModelProviderService()
result = model_provider_service.free_quota_qualification_verify(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
provider=provider,
token=args['token']
)
@@ -310,19 +228,18 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
api.add_resource(ModelProviderModelValidateApi,
'/workspaces/current/model-providers/<string:provider_name>/models/validate')
api.add_resource(ModelProviderModelUpdateApi,
'/workspaces/current/model-providers/<string:provider_name>/models')
api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate')
api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>')
api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/'
'<string:icon_type>/<string:lang>')
api.add_resource(PreferredProviderTypeUpdateApi,
'/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type')
api.add_resource(ModelProviderModelParameterRuleApi,
'/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
'/workspaces/current/model-providers/<string:provider>/preferred-provider-type')
api.add_resource(ModelProviderPaymentCheckoutUrlApi,
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')
'/workspaces/current/model-providers/<string:provider>/checkout-url')
api.add_resource(ModelProviderFreeQuotaSubmitApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
'/workspaces/current/model-providers/<string:provider>/free-quota-submit')
api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')
'/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify')

View File

@@ -1,16 +1,17 @@
import logging
from flask_login import current_user
from libs.login import login_required
from flask_restful import Resource, reqparse
from flask_restful import reqparse, Resource
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from models.provider import ProviderType
from services.provider_service import ProviderService
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.model_provider_service import ModelProviderService
class DefaultModelApi(Resource):
@@ -21,52 +22,20 @@ class DefaultModelApi(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
choices=[mt.value for mt in ModelType], location='args')
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
provider_service = ProviderService()
default_model = provider_service.get_default_model_of_model_type(
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id,
model_type=args['model_type']
)
if not default_model:
return None
model_provider = ModelProviderFactory.get_preferred_model_provider(
tenant_id,
default_model.provider_name
)
if not model_provider:
return {
'model_name': default_model.model_name,
'model_type': default_model.model_type,
'model_provider': {
'provider_name': default_model.provider_name
}
}
provider = model_provider.provider
rst = {
'model_name': default_model.model_name,
'model_type': default_model.model_type,
'model_provider': {
'provider_name': provider.provider_name,
'provider_type': provider.provider_type
}
}
model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
if provider.provider_type == ProviderType.SYSTEM.value:
rst['model_provider']['quota_type'] = provider.quota_type
rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
rst['model_provider']['quota_limit'] = provider.quota_limit
rst['model_provider']['quota_used'] = provider.quota_used
return rst
return jsonable_encoder({
"data": default_model_entity
})
@setup_required
@login_required
@@ -76,15 +45,26 @@ class DefaultModelApi(Resource):
parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
model_settings = args['model_settings']
for model_setting in model_settings:
if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]:
raise ValueError('invalid model type')
if 'provider' not in model_setting:
continue
if 'model' not in model_setting:
raise ValueError('invalid model')
try:
provider_service.update_default_model_of_model_type(
tenant_id=current_user.current_tenant_id,
model_provider_service.update_default_model_of_model_type(
tenant_id=tenant_id,
model_type=model_setting['model_type'],
provider_name=model_setting['provider_name'],
model_name=model_setting['model_name']
provider=model_setting['provider'],
model=model_setting['model']
)
except Exception:
logging.warning(f"{model_setting['model_type']} save error")
@@ -92,22 +72,198 @@ class DefaultModelApi(Resource):
return {'result': 'success'}
class ValidModelApi(Resource):
class ModelProviderModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(
tenant_id=tenant_id,
provider=provider
)
return jsonable_encoder({
"data": models
})
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
model_provider_service = ModelProviderService()
try:
model_provider_service.save_model_credentials(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
args = parser.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credentials(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
)
return {'result': 'success'}, 204
class ModelProviderModelCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='args')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='args')
args = parser.parse_args()
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_model_credentials(
tenant_id=tenant_id,
provider=provider,
model_type=args['model_type'],
model=args['model']
)
return {
"credentials": credentials
}
class ModelProviderModelValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
model_provider_service = ModelProviderService()
result = True
error = None
try:
model_provider_service.model_credentials_validate(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials']
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
class ModelProviderModelParameterRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
tenant_id=tenant_id,
provider=provider,
model=args['model']
)
return jsonable_encoder({
"data": parameter_rules
})
class ModelProviderAvailableModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, model_type):
ModelType.value_of(model_type)
tenant_id = current_user.current_tenant_id
provider_service = ProviderService()
valid_models = provider_service.get_valid_model_list(
tenant_id=current_user.current_tenant_id,
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(
tenant_id=tenant_id,
model_type=model_type
)
return valid_models
return jsonable_encoder({
"data": models
})
api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
api.add_resource(ModelProviderModelCredentialApi,
'/workspaces/current/model-providers/<string:provider>/models/credentials')
api.add_resource(ModelProviderModelValidateApi,
'/workspaces/current/model-providers/<string:provider>/models/credentials/validate')
api.add_resource(ModelProviderModelParameterRuleApi,
'/workspaces/current/model-providers/<string:provider>/models/parameter-rules')
api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>')
api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')

View File

@@ -1,131 +0,0 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from libs.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType
from services.provider_service import ProviderService
class ProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
"""
If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
rest is replaced by * and the last two bits are displayed in plaintext
If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
"""
provider_service = ProviderService()
provider_info_list = provider_service.get_provider_list(tenant_id)
provider_list = [
{
'provider_name': p['provider_name'],
'provider_type': p['provider_type'],
'is_valid': p['is_valid'],
'last_used': p['last_used'],
'is_enabled': p['is_valid'],
**({
'quota_type': p['quota_type'],
'quota_limit': p['quota_limit'],
'quota_used': p['quota_used']
} if p['provider_type'] == ProviderType.SYSTEM.value else {}),
'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
if p['config'] else None
}
for name, provider_info in provider_info_list.items()
for p in provider_info['providers']
]
return provider_list
class ProviderTokenApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('token', required=True, nullable=False, location='json')
args = parser.parse_args()
if provider == 'openai':
args['token'] = {
'openai_api_key': args['token']
}
provider_service = ProviderService()
try:
provider_service.save_custom_provider_config(
tenant_id=current_user.current_tenant_id,
provider_name=provider,
config=args['token']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 201
class ProviderTokenValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument('token', required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
if provider == 'openai':
args['token'] = {
'openai_api_key': args['token']
}
result = True
error = None
try:
provider_service.custom_provider_config_validate(
provider_name=provider,
config=args['token']
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
endpoint='workspaces_current_providers_token') # PUT for updating provider token
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
endpoint='workspaces_current_providers_token_validate') # POST for validating provider token
api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list

View File

@@ -10,12 +10,15 @@ from controllers.console import api
from controllers.console.admin import admin_required
from controllers.console.setup import setup_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, UnsupportedFileTypeError
from libs.helper import TimestampField
from extensions.ext_database import db
from models.account import Tenant
import services
from services.account_service import TenantService
from services.workspace_service import WorkspaceService
from services.file_service import FileService
provider_fields = {
'provider_name': fields.String,
@@ -31,9 +34,9 @@ tenant_fields = {
'status': fields.String,
'created_at': TimestampField,
'role': fields.String,
'providers': fields.List(fields.Nested(provider_fields)),
'in_trial': fields.Boolean,
'trial_end_reason': fields.String,
'custom_config': fields.Raw(attribute='custom_config'),
}
tenants_fields = {
@@ -130,6 +133,61 @@ class SwitchWorkspaceApi(Resource):
new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant
return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
class CustomConfigWorkspaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('workspace_custom')
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('remove_webapp_brand', type=bool, location='json')
parser.add_argument('replace_webapp_logo', type=str, location='json')
args = parser.parse_args()
custom_config_dict = {
'remove_webapp_brand': args['remove_webapp_brand'],
'replace_webapp_logo': args['replace_webapp_logo'],
}
tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404()
tenant.custom_config_dict = custom_config_dict
db.session.commit()
return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
class WebappLogoWorkspaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('workspace_custom')
def post(self):
# get file from request
file = request.files['file']
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
extension = file.filename.split('.')[-1]
if extension.lower() not in ['svg', 'png']:
raise UnsupportedFileTypeError()
try:
upload_file = FileService.upload_file(file, current_user, True)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return { 'id': upload_file.id }, 201
api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants
@@ -137,3 +195,5 @@ api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all ten
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info
api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config')
api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload')

View File

@@ -5,7 +5,7 @@ from flask import current_app, abort
from flask_login import current_user
from controllers.console.workspace.error import AccountNotInitializedError
from services.billing_service import BillingService
from services.feature_service import FeatureService
def account_initialization_required(view):
@@ -49,18 +49,23 @@ def cloud_edition_billing_resource_check(resource: str,
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
if current_app.config['EDITION'] == 'CLOUD':
tenant_id = current_user.current_tenant_id
billing_info = BillingService.get_info(tenant_id)
members = billing_info['members']
apps = billing_info['apps']
vector_space = billing_info['vector_space']
features = FeatureService.get_features(current_user.current_tenant_id)
if resource == 'members' and 0 < members['limit'] <= members['size']:
if features.billing.enabled:
members = features.members
apps = features.apps
vector_space = features.vector_space
annotation_quota_limit = features.annotation_quota_limit
if resource == 'members' and 0 < members.limit <= members.size:
abort(403, error_msg)
elif resource == 'apps' and 0 < apps['limit'] <= apps['size']:
elif resource == 'apps' and 0 < apps.limit <= apps.size:
abort(403, error_msg)
elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
abort(403, error_msg)
elif resource == 'workspace_custom' and not features.can_replace_logo:
abort(403, error_msg)
elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
abort(403, error_msg)
else:
return view(*args, **kwargs)

View File

@@ -1,10 +1,12 @@
from flask import request, Response
from flask_restful import Resource
from werkzeug.exceptions import NotFound
import services
from controllers.files import api
from libs.exception import BaseHTTPException
from services.file_service import FileService
from services.account_service import TenantService
class ImagePreviewApi(Resource):
@@ -29,9 +31,30 @@ class ImagePreviewApi(Resource):
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
class WorkspaceWebappLogoApi(Resource):
def get(self, workspace_id):
workspace_id = str(workspace_id)
custom_config = TenantService.get_custom_config(workspace_id)
webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None
if not webapp_logo_file_id:
raise NotFound(f'webapp logo is not found')
try:
generator, mimetype = FileService.get_public_image_preview(
webapp_logo_file_id,
)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces/<uuid:workspace_id>/webapp-logo')
class UnsupportedFileTypeError(BaseHTTPException):

View File

@@ -31,6 +31,7 @@ class AppParameterApi(AppApiResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
@@ -49,6 +50,7 @@ class AppParameterApi(AppApiResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,

View File

@@ -9,8 +9,8 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
ProviderNotSupportSpeechToTextError
from controllers.service_api.wraps import AppApiResource
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppModelConfig
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
@@ -49,8 +49,7 @@ class AudioApi(AppApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e

View File

@@ -13,9 +13,10 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
ConversationCompletedError, CompletionRequestError, ProviderQuotaExceededError, \
ProviderModelCurrentlyNotSupportError
from controllers.service_api.wraps import AppApiResource
from core.conversation_message_task import PubHandler
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from services.completion_service import CompletionService
@@ -47,7 +48,7 @@ class CompletionApi(AppApiResource):
app_model=app_model,
user=end_user,
args=args,
from_source='api',
invoke_from=InvokeFrom.SERVICE_API,
streaming=streaming,
)
@@ -65,8 +66,7 @@ class CompletionApi(AppApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -80,7 +80,7 @@ class CompletionStopApi(AppApiResource):
if app_model.mode != 'completion':
raise AppUnavailableError()
PubHandler.stop(end_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
@@ -98,7 +98,7 @@ class ChatApi(AppApiResource):
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
args = parser.parse_args()
@@ -112,7 +112,7 @@ class ChatApi(AppApiResource):
app_model=app_model,
user=end_user,
args=args,
from_source='api',
invoke_from=InvokeFrom.SERVICE_API,
streaming=streaming
)
@@ -130,8 +130,7 @@ class ChatApi(AppApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -145,7 +144,7 @@ class ChatStopApi(AppApiResource):
if app_model.mode != 'chat':
raise NotChatAppError()
PubHandler.stop(end_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
@@ -171,8 +170,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

View File

@@ -4,11 +4,11 @@ import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from libs.login import current_user
from core.model_providers.models.entity.model_params import ModelType
from fields.dataset_fields import dataset_detail_fields
from services.dataset_service import DatasetService
from services.provider_service import ProviderService
def _validate_name(name):
@@ -27,12 +27,20 @@ class DatasetApi(DatasetApiResource):
datasets, total = DatasetService.get_datasets(page, limit, provider,
tenant_id, current_user)
# check embedding setting
provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)
embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item['indexing_technique'] == 'high_quality':

View File

@@ -13,7 +13,7 @@ from controllers.service_api.dataset.error import ArchivedDocumentImmutableError
NoFileUploadedError, TooManyFilesError
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
from libs.login import current_user
from core.model_providers.error import ProviderTokenNotInitError
from core.errors.error import ProviderTokenNotInitError
from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from models.dataset import Dataset, Document, DocumentSegment

View File

@@ -4,8 +4,9 @@ from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from fields.segment_fields import segment_fields
from models.dataset import Dataset, DocumentSegment
@@ -35,10 +36,12 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
@@ -77,10 +80,12 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
@@ -167,10 +172,12 @@ class DatasetSegmentApi(DatasetApiResource):
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(

View File

@@ -11,8 +11,7 @@ from libs.login import _get_user
from extensions.ext_database import db
from models.account import Tenant, TenantAccountJoin, Account
from models.model import ApiToken, App
from services.billing_service import BillingService
from services.feature_service import FeatureService
def validate_app_token(view=None):
def decorator(view):
@@ -46,19 +45,19 @@ def cloud_edition_billing_resource_check(resource: str,
error_msg: str = "You have reached the limit of your subscription."):
def interceptor(view):
def decorated(*args, **kwargs):
if current_app.config['EDITION'] == 'CLOUD':
api_token = validate_and_get_api_token(api_token_type)
billing_info = BillingService.get_info(api_token.tenant_id)
api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id)
members = billing_info['members']
apps = billing_info['apps']
vector_space = billing_info['vector_space']
if features.billing.enabled:
members = features.members
apps = features.apps
vector_space = features.vector_space
if resource == 'members' and 0 < members['limit'] <= members['size']:
if resource == 'members' and 0 < members.limit <= members.size:
raise Unauthorized(error_msg)
elif resource == 'apps' and 0 < apps['limit'] <= apps['size']:
elif resource == 'apps' and 0 < apps.limit <= apps.size:
raise Unauthorized(error_msg)
elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
raise Unauthorized(error_msg)
else:
return view(*args, **kwargs)

View File

@@ -30,6 +30,7 @@ class AppParameterApi(WebApiResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
@@ -48,6 +49,7 @@ class AppParameterApi(WebApiResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,

View File

@@ -10,8 +10,8 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.web.wraps import WebApiResource
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@@ -51,8 +51,7 @@ class AudioApi(WebApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e

View File

@@ -13,9 +13,10 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
ProviderNotInitializeError, NotChatAppError, NotCompletionAppError, CompletionRequestError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
from core.conversation_message_task import PubHandler
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from services.completion_service import CompletionService
@@ -44,7 +45,7 @@ class CompletionApi(WebApiResource):
app_model=app_model,
user=end_user,
args=args,
from_source='api',
invoke_from=InvokeFrom.WEB_APP,
streaming=streaming
)
@@ -62,8 +63,7 @@ class CompletionApi(WebApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -77,7 +77,7 @@ class CompletionStopApi(WebApiResource):
if app_model.mode != 'completion':
raise NotCompletionAppError()
PubHandler.stop(end_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
return {'result': 'success'}, 200
@@ -105,7 +105,7 @@ class ChatApi(WebApiResource):
app_model=app_model,
user=end_user,
args=args,
from_source='api',
invoke_from=InvokeFrom.WEB_APP,
streaming=streaming
)
@@ -123,8 +123,7 @@ class ChatApi(WebApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -138,7 +137,7 @@ class ChatStopApi(WebApiResource):
if app_model.mode != 'chat':
raise NotChatAppError()
PubHandler.stop(end_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
return {'result': 'success'}, 200
@@ -164,8 +163,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"

View File

@@ -14,8 +14,9 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
@@ -117,7 +118,14 @@ class MessageMoreLikeThisApi(WebApiResource):
streaming = args['response_mode'] == 'streaming'
try:
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
response = CompletionService.generate_more_like_this(
app_model=app_model,
user=end_user,
message_id=message_id,
invoke_from=InvokeFrom.WEB_APP,
streaming=streaming
)
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@@ -129,8 +137,7 @@ class MessageMoreLikeThisApi(WebApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -157,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -195,8 +201,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")

View File

@@ -1,11 +1,15 @@
# -*- coding:utf-8 -*-
import os
from flask_restful import fields, marshal_with
from flask import current_app
from werkzeug.exceptions import Forbidden
from controllers.web import api
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from models.model import Site
from services.feature_service import FeatureService
class AppSiteApi(WebApiResource):
@@ -39,6 +43,8 @@ class AppSiteApi(WebApiResource):
'site': fields.Nested(site_fields),
'model_config': fields.Nested(model_config_fields, allow_null=True),
'plan': fields.String,
'can_replace_logo': fields.Boolean,
'custom_config': fields.Raw(attribute='custom_config'),
}
@marshal_with(app_fields)
@@ -50,7 +56,9 @@ class AppSiteApi(WebApiResource):
if not site:
raise Forbidden()
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id)
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
api.add_resource(AppSiteApi, '/site')
@@ -59,7 +67,7 @@ api.add_resource(AppSiteApi, '/site')
class AppSiteInfo:
"""Class to store site information."""
def __init__(self, tenant, app, site, end_user):
def __init__(self, tenant, app, site, end_user, can_replace_logo):
"""Initialize AppSiteInfo instance."""
self.app_id = app.id
self.end_user_id = end_user
@@ -67,6 +75,16 @@ class AppSiteInfo:
self.site = site
self.model_config = None
self.plan = tenant.plan
self.can_replace_logo = can_replace_logo
if can_replace_logo:
base_url = current_app.config.get('FILES_URL')
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
self.custom_config = {
'remove_webapp_brand': remove_webapp_brand,
'replace_webapp_logo': replace_webapp_logo,
}
if app.enable_site and site.prompt_public:
app_model_config = app.app_model_config

View File

@@ -0,0 +1,101 @@
import logging
from typing import Optional, List
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
from core.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class AgentLLMCallback(Callback):
def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None:
self.agent_callback = agent_callback
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Before invoke callback
:param llm_instance: LLM instance
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.agent_callback.on_llm_before_invoke(
prompt_messages=prompt_messages
)
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None):
"""
On new chunk callback
:param llm_instance: LLM instance
:param chunk: chunk
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
pass
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
After invoke callback
:param llm_instance: LLM instance
:param result: result
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.agent_callback.on_llm_after_invoke(
result=result
)
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Invoke error callback
:param llm_instance: LLM instance
:param ex: exception
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.agent_callback.on_llm_error(
error=ex
)

View File

@@ -1,28 +1,49 @@
from typing import List
from typing import List, cast
from langchain.schema import BaseMessage
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class CalcTokenMixin:
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
return model_instance.get_num_tokens(to_prompt_messages(messages))
def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
"""
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
:param llm:
:param model_config:
:param messages:
:return:
"""
llm_max_tokens = model_instance.model_rules.max_tokens.max
completion_max_tokens = model_instance.model_kwargs.max_tokens
used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return 0
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
return rest_tokens

View File

@@ -1,4 +1,3 @@
import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
@@ -6,13 +5,14 @@ from langchain.agents.openai_functions_agent.base import _format_intermediate_st
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage
from langchain.tools import BaseTool
from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.entities.application_entities import ModelConfigEntity
from core.model_manager import ModelInstance
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
@@ -20,7 +20,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
An Multi Dataset Retrieve Agent driven by Router.
"""
model_instance: BaseLLM
model_config: ModelConfigEntity
class Config:
"""Configuration for this pydantic object."""
@@ -81,8 +81,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
agent_decision.return_values['output'] = ''
return agent_decision
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
raise e
def real_plan(
self,
@@ -106,16 +105,39 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
prompt_messages = lc_messages_to_prompt_messages(messages)
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
ai_message = AIMessage(
content=result.content,
content=result.message.content or "",
additional_kwargs={
'function_call': result.function_call
'function_call': {
'id': result.message.tool_calls[0].id,
**result.message.tool_calls[0].function.dict()
} if result.message.tool_calls else None
}
)
@@ -133,7 +155,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@classmethod
def from_llm_and_tools(
cls,
model_instance: BaseLLM,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -147,7 +169,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
system_message=system_message,
)
return cls(
model_instance=model_instance,
model_config=model_config,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,

View File

@@ -1,4 +1,4 @@
from typing import List, Tuple, Any, Union, Sequence, Optional
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
@@ -13,18 +13,23 @@ from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage,
from langchain.tools import BaseTool
from pydantic import root_validator
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.entities.application_entities import ModelConfigEntity
from core.model_manager import ModelInstance
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.third_party.langchain.llms.fake import FakeLLM
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_instance: BaseLLM = None
model_instance: BaseLLM
summary_model_config: ModelConfigEntity = None
model_config: ModelConfigEntity
agent_llm_callback: Optional[AgentLLMCallback] = None
class Config:
"""Configuration for this pydantic object."""
@@ -38,13 +43,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
@classmethod
def from_llm_and_tools(
cls,
model_instance: BaseLLM,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> BaseSingleActionAgent:
prompt = cls.create_prompt(
@@ -52,11 +58,12 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
system_message=system_message,
)
return cls(
model_instance=model_instance,
model_config=model_config,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
**kwargs,
)
@@ -67,28 +74,49 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
:param query:
:return:
"""
original_max_tokens = self.model_instance.model_kwargs.max_tokens
self.model_instance.model_kwargs.max_tokens = 40
original_max_tokens = 0
for parameter_rule in self.model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
or self.model_config.parameters.get(parameter_rule.use_template)) or 0
self.model_config.parameters['max_tokens'] = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
callbacks=None
prompt_messages = lc_messages_to_prompt_messages(messages)
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
raise e
function_call = result.function_call
self.model_config.parameters['max_tokens'] = original_max_tokens
self.model_instance.model_kwargs.max_tokens = original_max_tokens
return True if function_call else False
return True if result.message.tool_calls else False
def plan(
self,
@@ -113,22 +141,46 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
ai_message = AIMessage(
content=result.content,
content=result.message.content or "",
additional_kwargs={
'function_call': result.function_call
'function_call': {
'id': result.message.tool_calls[0].id,
**result.message.tool_calls[0].function.dict()
} if result.message.tool_calls else None
}
)
agent_decision = _parse_ai_message(ai_message)
@@ -158,9 +210,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = self.get_message_rest_tokens(
self.model_config,
messages,
**kwargs
)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
@@ -210,19 +267,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
ai_prefix="AI",
)
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model_instance.model_provider.provider_name == 'azure_openai':
model = model_instance.base_model_name
if model_config.provider == 'azure_openai':
model = model_config.model
model = model.replace("gpt-35", "gpt-3.5")
else:
model = model_instance.base_model_name
model = model_config.credentials.get("base_model_name")
tiktoken_ = _import_tiktoken()
try:

View File

@@ -1,158 +0,0 @@
import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
An Multi Dataset Retrieve Agent driven by Router.
"""
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
def should_use_agent(self, query: str):
"""
return should use agent
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.tools) == 1:
tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'query': kwargs['input']})
# output = ''
# rst_json = json.loads(rst)
# for item in rst_json:
# output += f'{item["content"]}\n'
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
try:
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
else:
agent_decision.return_values['output'] = ''
return agent_decision
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
def real_plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
)
ai_message = AIMessage(
content=result.content,
additional_kwargs={
'function_call': result.function_call
}
)
agent_decision = _parse_ai_message(ai_message)
return agent_decision
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
raise NotImplementedError()
@classmethod
def from_llm_and_tools(
cls,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseSingleActionAgent:
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_instance=model_instance,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)

View File

@@ -12,9 +12,7 @@ from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.entities.application_entities import ModelConfigEntity
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -69,10 +67,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
@@ -101,8 +99,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
raise new_exception
raise e
try:
agent_decision = self.output_parser.parse(full_output)
@@ -119,6 +116,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
@classmethod
def create_prompt(
cls,
@@ -182,7 +180,7 @@ Thought: {agent_scratchpad}
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
@@ -193,7 +191,7 @@ Thought: {agent_scratchpad}
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
if llm_chain.model_config.mode == "chat":
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
@@ -207,7 +205,7 @@ Thought: {agent_scratchpad}
@classmethod
def from_llm_and_tools(
cls,
model_instance: BaseLLM,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
@@ -221,7 +219,7 @@ Thought: {agent_scratchpad}
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_instance.model_mode == ModelMode.CHAT:
if model_config.mode == "chat":
prompt = cls.create_prompt(
tools,
prefix=prefix,
@@ -238,10 +236,16 @@ Thought: {agent_scratchpad}
format_instructions=format_instructions,
input_variables=input_variables
)
llm_chain = LLMChain(
model_instance=model_instance,
model_config=model_config,
prompt=prompt,
callback_manager=callback_manager,
parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser

View File

@@ -13,10 +13,11 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage,
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -54,7 +55,7 @@ Action:
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_instance: BaseLLM = None
summary_model_config: ModelConfigEntity = None
class Config:
"""Configuration for this pydantic object."""
@@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
along with observatons
callbacks: Callbacks to run.
**kwargs: User inputs.
@@ -96,15 +97,16 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if prompts:
messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
prompt_messages = lc_messages_to_prompt_messages(messages)
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
raise new_exception
raise e
try:
agent_decision = self.output_parser.parse(full_output)
@@ -119,7 +121,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_model_instance:
if len(intermediate_steps) >= 2 and self.summary_model_config:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
@@ -153,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
ai_prefix="AI",
)
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
@classmethod
@@ -229,7 +231,7 @@ Thought: {agent_scratchpad}
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
if llm_chain.model_config.mode == "chat":
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
@@ -243,7 +245,7 @@ Thought: {agent_scratchpad}
@classmethod
def from_llm_and_tools(
cls,
model_instance: BaseLLM,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
@@ -253,11 +255,12 @@ Thought: {agent_scratchpad}
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_instance.model_mode == ModelMode.CHAT:
if model_config.mode == "chat":
prompt = cls.create_prompt(
tools,
prefix=prefix,
@@ -275,9 +278,15 @@ Thought: {agent_scratchpad}
input_variables=input_variables,
)
llm_chain = LLMChain(
model_instance=model_instance,
model_config=model_config,
prompt=prompt,
callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser

View File

@@ -4,10 +4,10 @@ from typing import Union, Optional
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
@@ -15,9 +15,11 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation
from core.model_providers.error import LLMError
from core.model_providers.models.llm.base import BaseLLM
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -31,14 +33,15 @@ class PlanningStrategy(str, enum.Enum):
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_instance: BaseLLM
model_config: ModelConfigEntity
tools: list[BaseTool]
summary_model_instance: BaseLLM = None
memory: Optional[BaseChatMemory] = None
summary_model_config: Optional[ModelConfigEntity] = None
memory: Optional[TokenBufferMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6
max_execution_time: Optional[float] = None
early_stopping_method: str = "generate"
agent_llm_callback: Optional[AgentLLMCallback] = None
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
class Config:
@@ -62,34 +65,42 @@ class AgentExecutor:
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
model_config=self.configuration.model_config,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_model_instance=self.configuration.summary_model_instance
if self.configuration.summary_model_instance else None,
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
model_config=self.configuration.model_config,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_model_instance=self.configuration.summary_model_instance
if self.configuration.summary_model_instance else None,
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None, # used for read chat histories memory
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
model_config=self.configuration.model_config,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
model_config=self.configuration.model_config,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
verbose=True
@@ -104,11 +115,11 @@ class AgentExecutor:
def run(self, query: str) -> AgentExecuteResult:
moderation_result = moderation.check_moderation(
self.configuration.model_instance.model_provider,
self.configuration.model_config,
query
)
if not moderation_result:
if moderation_result:
return AgentExecuteResult(
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
strategy=self.configuration.strategy,
@@ -118,7 +129,6 @@ class AgentExecutor:
agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent,
tools=self.configuration.tools,
memory=self.configuration.memory,
max_iterations=self.configuration.max_iterations,
max_execution_time=self.configuration.max_execution_time,
early_stopping_method=self.configuration.early_stopping_method,
@@ -126,8 +136,8 @@ class AgentExecutor:
)
try:
output = agent_executor.run(query)
except LLMError as ex:
output = agent_executor.run(input=query)
except InvokeError as ex:
raise ex
except Exception as ex:
logging.exception("agent_executor run failed")

View File

@@ -0,0 +1,251 @@
import json
import logging
from typing import cast
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.app_runner.app_runner import AppRunner
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, PromptTemplateEntity, ModelConfigEntity
from core.application_queue_manager import ApplicationQueueManager
from core.features.agent_runner import AgentRunnerFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from models.model import Conversation, Message, App, MessageChain, MessageAgentThought
logger = logging.getLogger(__name__)
class AgentApplicationRunner(AppRunner):
"""
Agent Application Runner
"""
def run(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
conversation: Conversation,
message: Message) -> None:
"""
Run agent application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
:return:
"""
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
if not app_record:
raise ValueError(f"App not found")
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query
)
memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model
)
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional)
prompt_messages, stop = self.originze_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query,
context=None,
memory=memory
)
# Create MessageChain
message_chain = self._init_message_chain(
message=message,
query=query
)
# add agent callback to record agent thoughts
agent_callback = AgentLoopGatherCallbackHandler(
model_config=app_orchestration_config.model_config,
message=message,
queue_manager=queue_manager,
message_chain=message_chain
)
# init LLM Callback
agent_llm_callback = AgentLLMCallback(
agent_callback=agent_callback
)
agent_runner = AgentRunnerFeature(
tenant_id=application_generate_entity.tenant_id,
app_orchestration_config=app_orchestration_config,
model_config=app_orchestration_config.model_config,
config=app_orchestration_config.agent,
queue_manager=queue_manager,
message=message,
user_id=application_generate_entity.user_id,
agent_llm_callback=agent_llm_callback,
callback=agent_callback,
memory=memory
)
# agent run
result = agent_runner.run(
query=query,
invoke_from=application_generate_entity.invoke_from
)
if result:
self._save_message_chain(
message_chain=message_chain,
output_text=result
)
if (result
and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE
and app_orchestration_config.prompt_template.simple_prompt_template
):
# Direct output if agent result exists and has pre prompt
self.direct_output(
queue_manager=queue_manager,
app_orchestration_config=app_orchestration_config,
prompt_messages=prompt_messages,
stream=application_generate_entity.stream,
text=result,
usage=self._get_usage_of_all_agent_thoughts(
model_config=app_orchestration_config.model_config,
message=message
)
)
else:
# As normal LLM run, agent result as context
context = result
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.originze_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query,
context=context,
memory=memory
)
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recale_llm_max_tokens(
model_config=app_orchestration_config.model_config,
prompt_messages=prompt_messages
)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model
)
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
stop=stop,
stream=application_generate_entity.stream,
user=application_generate_entity.user_id,
)
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream
)
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
"""
Init MessageChain
:param message: message
:param query: query
:return:
"""
message_chain = MessageChain(
message_id=message.id,
type="AgentExecutor",
input=json.dumps({
"input": query
})
)
db.session.add(message_chain)
db.session.commit()
return message_chain
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
"""
Save MessageChain
:param message_chain: message chain
:param output_text: output text
:return:
"""
message_chain.output = json.dumps({
"output": output_text
})
db.session.commit()
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
message: Message) -> LLMUsage:
"""
Get usage of all agent thoughts
:param model_config: model config
:param message: message
:return:
"""
agent_thoughts = (db.session.query(MessageAgentThought)
.filter(MessageAgentThought.message_id == message.id).all())
all_message_tokens = 0
all_answer_tokens = 0
for agent_thought in agent_thoughts:
all_message_tokens += agent_thought.message_tokens
all_answer_tokens += agent_thought.answer_tokens
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
return model_type_instance._calc_response_usage(
model_config.model,
model_config.credentials,
all_message_tokens,
all_answer_tokens
)

View File

@@ -0,0 +1,267 @@
import time
from typing import cast, Optional, List, Tuple, Generator, Union
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, AssistantPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform
from models.model import App
class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App,
model_config: ModelConfigEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list[FileObj],
query: Optional[str] = None) -> int:
"""
Get pre calculate rest tokens
:param app_record: app record
:param model_config: model config entity
:param prompt_template_entity: prompt template entity
:param inputs: inputs
:param files: files
:param query: query
:return:
"""
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
# get prompt messages without memory and context
prompt_messages, stop = self.originze_prompt_messages(
app_record=app_record,
model_config=model_config,
prompt_template_entity=prompt_template_entity,
inputs=inputs,
files=files,
query=query
)
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
prompt_messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size.")
return rest_tokens
def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
prompt_messages: List[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
prompt_messages
)
if prompt_tokens + max_tokens > model_context_tokens:
max_tokens = max(model_context_tokens - prompt_tokens, 16)
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
model_config.parameters[parameter_rule.name] = max_tokens
def originze_prompt_messages(self, app_record: App,
model_config: ModelConfigEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list[FileObj],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None) \
-> Tuple[List[PromptMessage], Optional[List[str]]]:
"""
Organize prompt messages
:param context:
:param app_record: app record
:param model_config: model config entity
:param prompt_template_entity: prompt template entity
:param inputs: inputs
:param files: files
:param query: query
:param memory: memory
:return:
"""
prompt_transform = PromptTransform()
# get prompt without memory and context
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
prompt_messages, stop = prompt_transform.get_prompt(
app_mode=app_record.mode,
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query if query else '',
files=files,
context=context,
memory=memory,
model_config=model_config
)
else:
prompt_messages = prompt_transform.get_advanced_prompt(
app_mode=app_record.mode,
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query,
files=files,
context=context,
memory=memory,
model_config=model_config
)
stop = model_config.stop
return prompt_messages, stop
def direct_output(self, queue_manager: ApplicationQueueManager,
app_orchestration_config: AppOrchestrationConfigEntity,
prompt_messages: list,
text: str,
stream: bool,
usage: Optional[LLMUsage] = None) -> None:
"""
Direct output
:param queue_manager: application queue manager
:param app_orchestration_config: app orchestration config
:param prompt_messages: prompt messages
:param text: text
:param stream: stream
:param usage: usage
:return:
"""
if stream:
index = 0
for token in text:
queue_manager.publish_chunk_message(LLMResultChunk(
model=app_orchestration_config.model_config.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=token)
)
))
index += 1
time.sleep(0.01)
queue_manager.publish_message_end(
llm_result=LLMResult(
model=app_orchestration_config.model_config.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage()
)
)
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
queue_manager: ApplicationQueueManager,
stream: bool) -> None:
"""
Handle invoke result
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param stream: stream
:return:
"""
if not stream:
self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager
)
else:
self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager
)
def _handle_invoke_result_direct(self, invoke_result: LLMResult,
queue_manager: ApplicationQueueManager) -> None:
"""
Handle invoke result direct
:param invoke_result: invoke result
:param queue_manager: application queue manager
:return:
"""
queue_manager.publish_message_end(
llm_result=invoke_result
)
def _handle_invoke_result_stream(self, invoke_result: Generator,
queue_manager: ApplicationQueueManager) -> None:
"""
Handle invoke result
:param invoke_result: invoke result
:param queue_manager: application queue manager
:return:
"""
model = None
prompt_messages = []
text = ''
usage = None
for result in invoke_result:
queue_manager.publish_chunk_message(result)
text += result.delta.message.content
if not model:
model = result.model
if not prompt_messages:
prompt_messages = result.prompt_messages
if not usage and result.delta.usage:
usage = result.delta.usage
llm_result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage
)
queue_manager.publish_message_end(
llm_result=llm_result
)

View File

@@ -0,0 +1,363 @@
import logging
from typing import Tuple, Optional
from core.app_runner.app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
from core.application_queue_manager import ApplicationQueueManager
from core.features.annotation_reply import AnnotationReplyFeature
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.external_data_fetch import ExternalDataFetchFeature
from core.features.hosting_moderation import HostingModerationFeature
from core.features.moderation import ModerationFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.moderation.base import ModerationException
from core.prompt.prompt_transform import AppMode
from extensions.ext_database import db
from models.model import Conversation, Message, App, MessageAnnotation
logger = logging.getLogger(__name__)
class BasicApplicationRunner(AppRunner):
"""
Basic Application Runner
"""
def run(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
conversation: Conversation,
message: Message) -> None:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
:return:
"""
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
if not app_record:
raise ValueError(f"App not found")
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query
)
memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model
)
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional)
prompt_messages, stop = self.originze_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query,
memory=memory
)
# moderation
try:
# process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs(
app_id=app_record.id,
tenant_id=application_generate_entity.tenant_id,
app_orchestration_config_entity=app_orchestration_config,
inputs=inputs,
query=query,
)
except ModerationException as e:
self.direct_output(
queue_manager=queue_manager,
app_orchestration_config=app_orchestration_config,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream
)
return
if query:
# annotation reply
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from
)
if annotation_reply:
queue_manager.publish_annotation_reply(
message_annotation_id=annotation_reply.id
)
self.direct_output(
queue_manager=queue_manager,
app_orchestration_config=app_orchestration_config,
prompt_messages=prompt_messages,
text=annotation_reply.content,
stream=application_generate_entity.stream
)
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_orchestration_config.external_data_variables
if external_data_tools:
inputs = self.fill_in_inputs_from_external_data_tools(
tenant_id=app_record.tenant_id,
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
# get context from datasets
context = None
if app_orchestration_config.dataset:
context = self.retrieve_dataset_context(
tenant_id=app_record.tenant_id,
app_record=app_record,
queue_manager=queue_manager,
model_config=app_orchestration_config.model_config,
show_retrieve_source=app_orchestration_config.show_retrieve_source,
dataset_config=app_orchestration_config.dataset,
message=message,
inputs=inputs,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
memory=memory
)
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.originze_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query,
context=context,
memory=memory
)
# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages
)
if hosting_moderation_result:
return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recale_llm_max_tokens(
model_config=app_orchestration_config.model_config,
prompt_messages=prompt_messages
)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model
)
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
stop=stop,
stream=application_generate_entity.stream,
user=application_generate_entity.user_id,
)
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream
)
def moderation_for_inputs(self, app_id: str,
tenant_id: str,
app_orchestration_config_entity: AppOrchestrationConfigEntity,
inputs: dict,
query: str) -> Tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
:param tenant_id: tenant id
:param app_orchestration_config_entity: app orchestration config entity
:param inputs: inputs
:param query: query
:return:
"""
moderation_feature = ModerationFeature()
return moderation_feature.check(
app_id=app_id,
tenant_id=tenant_id,
app_orchestration_config_entity=app_orchestration_config_entity,
inputs=inputs,
query=query,
)
def query_app_annotations_to_reply(self, app_record: App,
message: Message,
query: str,
user_id: str,
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
:param app_record: app record
:param message: message
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:return:
"""
annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query(
app_record=app_record,
message=message,
query=query,
user_id=user_id,
invoke_from=invoke_from
)
def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str) -> dict:
"""
Fill in variable inputs from external data tools if exists.
:param tenant_id: workspace id
:param app_id: app id
:param external_data_tools: external data tools configs
:param inputs: the inputs
:param query: the query
:return: the filled inputs
"""
external_data_fetch_feature = ExternalDataFetchFeature()
return external_data_fetch_feature.fetch(
tenant_id=tenant_id,
app_id=app_id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
def retrieve_dataset_context(self, tenant_id: str,
app_record: App,
queue_manager: ApplicationQueueManager,
model_config: ModelConfigEntity,
dataset_config: DatasetEntity,
show_retrieve_source: bool,
message: Message,
inputs: dict,
query: str,
user_id: str,
invoke_from: InvokeFrom,
memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
"""
Retrieve dataset context
:param tenant_id: tenant id
:param app_record: app record
:param queue_manager: queue manager
:param model_config: model config
:param dataset_config: dataset config
:param show_retrieve_source: show retrieve source
:param message: message
:param inputs: inputs
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:param memory: memory
:return:
"""
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
app_record.id,
message.id,
user_id,
invoke_from
)
if (app_record.mode == AppMode.COMPLETION.value and dataset_config
and dataset_config.retrieve_config.query_variable):
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
dataset_retrieval = DatasetRetrievalFeature()
return dataset_retrieval.retrieve(
tenant_id=tenant_id,
model_config=model_config,
config=dataset_config,
query=query,
invoke_from=invoke_from,
show_retrieve_source=show_retrieve_source,
hit_callback=hit_callback,
memory=memory
)
def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
prompt_messages: list[PromptMessage]) -> bool:
"""
Check hosting moderation
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param prompt_messages: prompt messages
:return:
"""
hosting_moderation_feature = HostingModerationFeature()
moderation_result = hosting_moderation_feature.check(
application_generate_entity=application_generate_entity,
prompt_messages=prompt_messages
)
if moderation_result:
self.direct_output(
queue_manager=queue_manager,
app_orchestration_config=application_generate_entity.app_orchestration_config_entity,
prompt_messages=prompt_messages,
text="I apologize for any confusion, " \
"but I'm an AI assistant to be helpful, harmless, and honest.",
stream=application_generate_entity.stream
)
return moderation_result

View File

@@ -0,0 +1,483 @@
import json
import logging
import time
from typing import Union, Generator, cast, Optional
from pydantic import BaseModel
from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
from core.entities.application_entities import ApplicationGenerateEntity
from core.application_queue_manager import ApplicationQueueManager
from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
AnnotationReplyEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, \
TextPromptMessageContent, PromptMessageContentType, ImagePromptMessageContent, PromptMessage
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Message, Conversation, MessageAgentThought
from services.annotation_service import AppAnnotationService
logger = logging.getLogger(__name__)
class TaskState(BaseModel):
"""
TaskState entity
"""
llm_result: LLMResult
metadata: dict = {}
class GenerateTaskPipeline:
"""
GenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
def __init__(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
conversation: Conversation,
message: Message) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation: conversation
:param message: message
"""
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
self._conversation = conversation
self._message = message
self._task_state = TaskState(
llm_result=LLMResult(
model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
prompt_messages=[],
message=AssistantPromptMessage(content=""),
usage=LLMUsage.empty_usage()
)
)
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
def process(self, stream: bool) -> Union[dict, Generator]:
"""
Process generate task pipeline.
:return:
"""
if stream:
return self._process_stream_response()
else:
return self._process_blocking_response()
def _process_blocking_response(self) -> dict:
"""
Process blocking response.
:return:
"""
for queue_message in self._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueueErrorEvent):
raise self._handle_error(event)
elif isinstance(event, QueueRetrieverResourcesEvent):
self._task_state.metadata['retriever_resources'] = event.retriever_resources
elif isinstance(event, AnnotationReplyEvent):
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation:
account = annotation.account
self._task_state.metadata['annotation_reply'] = {
'id': annotation.id,
'account': {
'id': annotation.account_id,
'name': account.name if account else 'Dify user'
}
}
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
else:
model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
model = model_config.model
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# calculate num tokens
prompt_tokens = 0
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
prompt_tokens = model_type_instance.get_num_tokens(
model,
model_config.credentials,
self._task_state.llm_result.prompt_messages
)
completion_tokens = 0
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
completion_tokens = model_type_instance.get_num_tokens(
model,
model_config.credentials,
[self._task_state.llm_result.message]
)
credentials = model_config.credentials
# transform usage
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
model,
credentials,
prompt_tokens,
completion_tokens
)
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
completion=self._task_state.llm_result.message.content,
public_event=False
)
# Save message
self._save_message(event.llm_result)
response = {
'event': 'message',
'task_id': self._application_generate_entity.task_id,
'id': self._message.id,
'mode': self._conversation.mode,
'answer': event.llm_result.message.content,
'metadata': {},
'created_at': int(self._message.created_at.timestamp())
}
if self._conversation.mode == 'chat':
response['conversation_id'] = self._conversation.id
if self._task_state.metadata:
response['metadata'] = self._task_state.metadata
return response
else:
continue
def _process_stream_response(self) -> Generator:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
event = message.event
if isinstance(event, QueueErrorEvent):
raise self._handle_error(event)
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
else:
model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
model = model_config.model
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# calculate num tokens
prompt_tokens = 0
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
prompt_tokens = model_type_instance.get_num_tokens(
model,
model_config.credentials,
self._task_state.llm_result.prompt_messages
)
completion_tokens = 0
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
completion_tokens = model_type_instance.get_num_tokens(
model,
model_config.credentials,
[self._task_state.llm_result.message]
)
credentials = model_config.credentials
# transform usage
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
model,
credentials,
prompt_tokens,
completion_tokens
)
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
completion=self._task_state.llm_result.message.content,
public_event=False
)
self._output_moderation_handler = None
replace_response = {
'event': 'message_replace',
'task_id': self._application_generate_entity.task_id,
'message_id': self._message.id,
'answer': self._task_state.llm_result.message.content,
'created_at': int(self._message.created_at.timestamp())
}
if self._conversation.mode == 'chat':
replace_response['conversation_id'] = self._conversation.id
yield self._yield_response(replace_response)
# Save message
self._save_message(self._task_state.llm_result)
response = {
'event': 'message_end',
'task_id': self._application_generate_entity.task_id,
'id': self._message.id,
}
if self._conversation.mode == 'chat':
response['conversation_id'] = self._conversation.id
if self._task_state.metadata:
response['metadata'] = self._task_state.metadata
yield self._yield_response(response)
elif isinstance(event, QueueRetrieverResourcesEvent):
self._task_state.metadata['retriever_resources'] = event.retriever_resources
elif isinstance(event, AnnotationReplyEvent):
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation:
account = annotation.account
self._task_state.metadata['annotation_reply'] = {
'id': annotation.id,
'account': {
'id': annotation.account_id,
'name': account.name if account else 'Dify user'
}
}
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
agent_thought = (
db.session.query(MessageAgentThought)
.filter(MessageAgentThought.id == event.agent_thought_id)
.first()
)
if agent_thought:
response = {
'event': 'agent_thought',
'id': agent_thought.id,
'task_id': self._application_generate_entity.task_id,
'message_id': self._message.id,
'position': agent_thought.position,
'thought': agent_thought.thought,
'tool': agent_thought.tool,
'tool_input': agent_thought.tool_input,
'created_at': int(self._message.created_at.timestamp())
}
if self._conversation.mode == 'chat':
response['conversation_id'] = self._conversation.id
yield self._yield_response(response)
elif isinstance(event, QueueMessageEvent):
chunk = event.chunk
delta_text = chunk.delta.message.content
if delta_text is None:
continue
if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
self._queue_manager.publish_chunk_message(LLMResultChunk(
model=self._task_state.llm_result.model,
prompt_messages=self._task_state.llm_result.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
)
))
self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION))
continue
else:
self._output_moderation_handler.append_new_token(delta_text)
self._task_state.llm_result.message.content += delta_text
response = self._handle_chunk(delta_text)
yield self._yield_response(response)
elif isinstance(event, QueueMessageReplaceEvent):
response = {
'event': 'message_replace',
'task_id': self._application_generate_entity.task_id,
'message_id': self._message.id,
'answer': event.text,
'created_at': int(self._message.created_at.timestamp())
}
if self._conversation.mode == 'chat':
response['conversation_id'] = self._conversation.id
yield self._yield_response(response)
elif isinstance(event, QueuePingEvent):
yield "event: ping\n\n"
else:
continue
def _save_message(self, llm_result: LLMResult) -> None:
"""
Save message.
:param llm_result: llm result
:return:
"""
usage = llm_result.usage
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
if llm_result.message.content else ''
self._message.answer_tokens = usage.completion_tokens
self._message.answer_unit_price = usage.completion_unit_price
self._message.answer_price_unit = usage.completion_price_unit
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.total_price = usage.total_price
db.session.commit()
message_was_created.send(
self._message,
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras
)
def _handle_chunk(self, text: str) -> dict:
"""
Handle completed event.
:param text: text
:return:
"""
response = {
'event': 'message',
'id': self._message.id,
'task_id': self._application_generate_entity.task_id,
'message_id': self._message.id,
'answer': text,
'created_at': int(self._message.created_at.timestamp())
}
if self._conversation.mode == 'chat':
response['conversation_id'] = self._conversation.id
return response
def _handle_error(self, event: QueueErrorEvent) -> Exception:
"""
Handle error event.
:param event: event
:return:
"""
logger.debug("error: %s", event.error)
e = event.error
if isinstance(e, InvokeAuthorizationError):
return InvokeAuthorizationError('Incorrect API key provided')
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
return e
else:
return Exception(e.description if getattr(e, 'description', None) is not None else str(e))
def _yield_response(self, response: dict) -> str:
"""
Yield response.
:param response: response
:return:
"""
return "data: " + json.dumps(response) + "\n\n"
def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
"""
Prompt messages to prompt for saving.
:param prompt_messages: prompt messages
:return:
"""
prompts = []
if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
for prompt_message in prompt_messages:
if prompt_message.role == PromptMessageRole.USER:
role = 'user'
elif prompt_message.role == PromptMessageRole.ASSISTANT:
role = 'assistant'
elif prompt_message.role == PromptMessageRole.SYSTEM:
role = 'system'
else:
continue
text = ''
files = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)
files.append({
"type": 'image',
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
"detail": content.detail.value
})
else:
text = prompt_message.content
prompts.append({
"role": role,
"text": text,
"files": files
})
else:
prompts.append({
"role": 'user',
"text": prompt_messages[0].content
})
return prompts
def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
"""
Init output moderation.
:return:
"""
app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance
if sensitive_word_avoidance:
return OutputModerationHandler(
tenant_id=self._application_generate_entity.tenant_id,
app_id=self._application_generate_entity.app_id,
rule=ModerationRule(
type=sensitive_word_avoidance.type,
config=sensitive_word_avoidance.config
),
on_message_replace_func=self._queue_manager.publish_message_replace
)

View File

@@ -0,0 +1,138 @@
import logging
import threading
import time
from typing import Any, Optional, Dict
from flask import current_app, Flask
from pydantic import BaseModel
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory
logger = logging.getLogger(__name__)
class ModerationRule(BaseModel):
type: str
config: Dict[str, Any]
class OutputModerationHandler(BaseModel):
DEFAULT_BUFFER_SIZE: int = 300
tenant_id: str
app_id: str
rule: ModerationRule
on_message_replace_func: Any
thread: Optional[threading.Thread] = None
thread_running: bool = True
buffer: str = ''
is_final_chunk: bool = False
final_output: Optional[str] = None
class Config:
arbitrary_types_allowed = True
def should_direct_output(self):
return self.final_output is not None
def get_final_output(self):
return self.final_output
def append_new_token(self, token: str):
self.buffer += token
if not self.thread:
self.thread = self.start_thread()
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
self.buffer = completion
self.is_final_chunk = True
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=completion
)
if not result or not result.flagged:
return completion
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
else:
final_output = result.text
if public_event:
self.on_message_replace_func(final_output)
return final_output
def start_thread(self) -> threading.Thread:
buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
thread = threading.Thread(target=self.worker, kwargs={
'flask_app': current_app._get_current_object(),
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
})
thread.start()
return thread
def stop_thread(self):
if self.thread and self.thread.is_alive():
self.thread_running = False
def worker(self, flask_app: Flask, buffer_size: int):
with flask_app.app_context():
current_length = 0
while self.thread_running:
moderation_buffer = self.buffer
buffer_length = len(moderation_buffer)
if not self.is_final_chunk:
chunk_length = buffer_length - current_length
if 0 <= chunk_length < buffer_size:
time.sleep(1)
continue
current_length = buffer_length
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=moderation_buffer
)
if not result or not result.flagged:
continue
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
self.final_output = final_output
else:
final_output = result.text + self.buffer[len(moderation_buffer):]
# trigger replace event
if self.thread_running:
self.on_message_replace_func(final_output)
if result.action == ModerationAction.DIRECT_OUTPUT:
break
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
try:
moderation_factory = ModerationFactory(
name=self.rule.type,
app_id=app_id,
tenant_id=tenant_id,
config=self.rule.config
)
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
logger.error("Moderation Output error: %s", e)
return None

View File

@@ -0,0 +1,655 @@
import json
import logging
import threading
import uuid
from typing import cast, Optional, Any, Union, Generator, Tuple
from flask import Flask, current_app
from pydantic import ValidationError
from core.app_runner.agent_app_runner import AgentApplicationRunner
from core.app_runner.basic_app_runner import BasicApplicationRunner
from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
from core.entities.application_entities import ApplicationGenerateEntity, AppOrchestrationConfigEntity, \
ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \
AdvancedCompletionPromptTemplateEntity, ExternalDataVariableEntity, DatasetEntity, DatasetRetrieveConfigEntity, \
AgentEntity, AgentToolEntity, FileUploadEntity, SensitiveWordAvoidanceEntity, InvokeFrom
from core.entities.model_entities import ModelStatus
from core.file.file_obj import FileObj
from core.errors.error import QuotaExceededError, ProviderTokenNotInitError, ModelCurrentlyNotSupportError
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from core.provider_manager import ProviderManager
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser, Conversation, Message, MessageFile, App
logger = logging.getLogger(__name__)
class ApplicationManager:
"""
This class is responsible for managing application
"""
def generate(self, tenant_id: str,
app_id: str,
app_model_config_id: str,
app_model_config_dict: dict,
app_model_config_override: bool,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
inputs: dict[str, str],
query: Optional[str] = None,
files: Optional[list[FileObj]] = None,
conversation: Optional[Conversation] = None,
stream: bool = False,
extras: Optional[dict[str, Any]] = None) \
-> Union[dict, Generator]:
"""
Generate App response.
:param tenant_id: workspace ID
:param app_id: app ID
:param app_model_config_id: app model config id
:param app_model_config_dict: app model config dict
:param app_model_config_override: app model config override
:param user: account or end user
:param invoke_from: invoke from source
:param inputs: inputs
:param query: query
:param files: file obj list
:param conversation: conversation
:param stream: is stream
:param extras: extras
"""
# init task id
task_id = str(uuid.uuid4())
# init application generate entity
application_generate_entity = ApplicationGenerateEntity(
task_id=task_id,
tenant_id=tenant_id,
app_id=app_id,
app_model_config_id=app_model_config_id,
app_model_config_dict=app_model_config_dict,
app_orchestration_config_entity=self._convert_from_app_model_config_dict(
tenant_id=tenant_id,
app_model_config_dict=app_model_config_dict
),
app_model_config_override=app_model_config_override,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else inputs,
query=query.replace('\x00', '') if query else None,
files=files if files else [],
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = ApplicationQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread.start()
# return response or stream generator
return self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
stream=stream
)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
conversation_id: str,
message_id: str) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation_id: conversation ID
:param message_id: message ID
:return:
"""
with flask_app.app_context():
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if application_generate_entity.app_orchestration_config_entity.agent:
# agent app
runner = AgentApplicationRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
else:
# basic app
runner = BasicApplicationRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except ConversationTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e)
except (ValueError, InvokeError) as e:
queue_manager.publish_error(e)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e)
finally:
db.session.remove()
def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
conversation: Conversation,
message: Message,
stream: bool = False) -> Union[dict, Generator]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation: conversation
:param message: message
:param stream: is stream
:return:
"""
# init generate task pipeline
generate_task_pipeline = GenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
try:
return generate_task_pipeline.process(stream=stream)
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise ConversationTaskStoppedException()
else:
logger.exception(e)
raise e
finally:
db.session.remove()
def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
-> AppOrchestrationConfigEntity:
"""
Convert app model config dict to entity.
:param tenant_id: tenant ID
:param app_model_config_dict: app model config dict
:raises ProviderTokenNotInitError: provider token not init error
:return: app orchestration config entity
"""
properties = {}
copy_app_model_config_dict = app_model_config_dict.copy()
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=copy_app_model_config_dict['model']['provider'],
model_type=ModelType.LLM
)
provider_name = provider_model_bundle.configuration.provider.provider
model_name = copy_app_model_config_dict['model']['name']
model_type_instance = provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# check model credentials
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=copy_app_model_config_dict['model']['name']
)
if model_credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=copy_app_model_config_dict['model']['name'],
model_type=ModelType.LLM
)
if provider_model is None:
model_name = copy_app_model_config_dict['model']['name']
raise ValueError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = copy_app_model_config_dict['model'].get('completion_params')
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
# get model mode
model_mode = copy_app_model_config_dict['model'].get('mode')
if not model_mode:
mode_enum = model_type_instance.get_model_mode(
model=copy_app_model_config_dict['model']['name'],
credentials=model_credentials
)
model_mode = mode_enum.value
model_schema = model_type_instance.get_model_schema(
copy_app_model_config_dict['model']['name'],
model_credentials
)
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
properties['model_config'] = ModelConfigEntity(
provider=copy_app_model_config_dict['model']['provider'],
model=copy_app_model_config_dict['model']['name'],
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
# prompt template
prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
properties['prompt_template'] = PromptTemplateEntity(
prompt_type=prompt_type,
simple_prompt_template=simple_prompt_template
)
else:
advanced_chat_prompt_template = None
chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append({
"text": message["text"],
"role": PromptMessageRole.value_of(message["role"])
})
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
messages=chat_prompt_messages
)
advanced_completion_prompt_template = None
completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
if completion_prompt_config:
completion_prompt_template_params = {
'prompt': completion_prompt_config['prompt']['text'],
}
if 'conversation_histories_role' in completion_prompt_config:
completion_prompt_template_params['role_prefix'] = {
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
}
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
**completion_prompt_template_params
)
properties['prompt_template'] = PromptTemplateEntity(
prompt_type=prompt_type,
advanced_chat_prompt_template=advanced_chat_prompt_template,
advanced_completion_prompt_template=advanced_completion_prompt_template
)
# external data variables
properties['external_data_variables'] = []
external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
for external_data_tool in external_data_tools:
if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
continue
properties['external_data_variables'].append(
ExternalDataVariableEntity(
variable=external_data_tool['variable'],
type=external_data_tool['type'],
config=external_data_tool['config']
)
)
# show retrieve source
show_retrieve_source = False
retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
if retriever_resource_dict:
if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
show_retrieve_source = True
properties['show_retrieve_source'] = show_retrieve_source
if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
and 'enabled' in copy_app_model_config_dict['agent_mode'] and copy_app_model_config_dict['agent_mode'][
'enabled']:
agent_dict = copy_app_model_config_dict.get('agent_mode')
if agent_dict['strategy'] in ['router', 'react_router']:
dataset_ids = []
for tool in agent_dict.get('tools', []):
key = list(tool.keys())[0]
if key != 'dataset':
continue
tool_item = tool[key]
if "enabled" not in tool_item or not tool_item["enabled"]:
continue
dataset_id = tool_item['id']
dataset_ids.append(dataset_id)
dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
query_variable = copy_app_model_config_dict.get('dataset_query_variable')
if dataset_configs['retrieval_model'] == 'single':
properties['dataset'] = DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
),
single_strategy=agent_dict['strategy']
)
)
else:
properties['dataset'] = DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
),
top_k=dataset_configs.get('top_k'),
score_threshold=dataset_configs.get('score_threshold'),
reranking_model=dataset_configs.get('reranking_model')
)
)
else:
if agent_dict['strategy'] == 'react':
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
strategy = AgentEntity.Strategy.FUNCTION_CALLING
agent_tools = []
for tool in agent_dict.get('tools', []):
key = list(tool.keys())[0]
tool_item = tool[key]
agent_tool_properties = {
"tool_id": key
}
if "enabled" not in tool_item or not tool_item["enabled"]:
continue
agent_tool_properties["config"] = tool_item
agent_tools.append(AgentToolEntity(**agent_tool_properties))
properties['agent'] = AgentEntity(
provider=properties['model_config'].provider,
model=properties['model_config'].model,
strategy=strategy,
tools=agent_tools
)
# file upload
file_upload_dict = copy_app_model_config_dict.get('file_upload')
if file_upload_dict:
if 'image' in file_upload_dict and file_upload_dict['image']:
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
properties['file_upload'] = FileUploadEntity(
image_config={
'number_limits': file_upload_dict['image']['number_limits'],
'detail': file_upload_dict['image']['detail'],
'transfer_methods': file_upload_dict['image']['transfer_methods']
}
)
# opening statement
properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')
# suggested questions after answer
suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
if suggested_questions_after_answer_dict:
if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
properties['suggested_questions_after_answer'] = True
# more like this
more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
if more_like_this_dict:
if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
properties['more_like_this'] = copy_app_model_config_dict.get('opening_statement')
# speech to text
speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
if speech_to_text_dict:
if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
properties['speech_to_text'] = True
# sensitive word avoidance
sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
if sensitive_word_avoidance_dict:
if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get('type'),
config=sensitive_word_avoidance_dict.get('config'),
)
return AppOrchestrationConfigEntity(**properties)
def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
-> Tuple[Conversation, Message]:
"""
Initialize generate records
:param application_generate_entity: application generate entity
:return:
"""
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_schema = model_type_instance.get_model_schema(
model=app_orchestration_config_entity.model_config.model,
credentials=app_orchestration_config_entity.model_config.credentials
)
app_record = (db.session.query(App)
.filter(App.id == application_generate_entity.app_id).first())
app_mode = app_record.mode
# get from source
end_user_id = None
account_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
from_source = 'api'
end_user_id = application_generate_entity.user_id
else:
from_source = 'console'
account_id = application_generate_entity.user_id
override_model_configs = None
if application_generate_entity.app_model_config_override:
override_model_configs = application_generate_entity.app_model_config_dict
introduction = ''
if app_mode == 'chat':
# get conversation introduction
introduction = self._get_conversation_introduction(application_generate_entity)
if not application_generate_entity.conversation_id:
conversation = Conversation(
app_id=app_record.id,
app_model_config_id=application_generate_entity.app_model_config_id,
model_provider=app_orchestration_config_entity.model_config.provider,
model_id=app_orchestration_config_entity.model_config.model,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_mode,
name='New conversation',
inputs=application_generate_entity.inputs,
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status='normal',
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
)
db.session.add(conversation)
db.session.commit()
else:
conversation = (
db.session.query(Conversation)
.filter(
Conversation.id == application_generate_entity.conversation_id,
Conversation.app_id == app_record.id
).first()
)
currency = model_schema.pricing.currency if model_schema.pricing else 'USD'
message = Message(
app_id=app_record.id,
model_provider=app_orchestration_config_entity.model_config.provider,
model_id=app_orchestration_config_entity.model_config.model,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
query=application_generate_entity.query or "",
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency=currency,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
agent_based=app_orchestration_config_entity.agent is not None
)
db.session.add(message)
db.session.commit()
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
url=file.url,
upload_file_id=file.upload_file_id,
created_by_role=('account' if account_id else 'end_user'),
created_by=account_id or end_user_id,
)
db.session.add(message_file)
db.session.commit()
return conversation, message
def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
"""
Get conversation introduction
:param application_generate_entity: application generate entity
:return: conversation introduction
"""
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
introduction = app_orchestration_config_entity.opening_statement
if introduction:
try:
inputs = application_generate_entity.inputs
prompt_template = PromptTemplateParser(template=introduction)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
introduction = prompt_template.format(prompt_inputs)
except KeyError:
pass
return introduction
def _get_conversation(self, conversation_id: str) -> Conversation:
"""
Get conversation by conversation id
:param conversation_id: conversation id
:return: conversation
"""
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id)
.first()
)
return conversation
def _get_message(self, message_id: str) -> Message:
"""
Get message by message id
:param message_id: message id
:return: message
"""
message = (
db.session.query(Message)
.filter(Message.id == message_id)
.first()
)
return message

View File

@@ -0,0 +1,228 @@
import queue
import time
from typing import Generator, Any
from sqlalchemy.orm import DeclarativeMeta
from core.entities.application_entities import InvokeFrom
from core.entities.queue_entities import QueueStopEvent, AppQueueEvent, QueuePingEvent, QueueErrorEvent, \
QueueAgentThoughtEvent, QueueMessageEndEvent, QueueRetrieverResourcesEvent, QueueMessageReplaceEvent, \
QueueMessageEvent, QueueMessage, AnnotationReplyEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from extensions.ext_redis import redis_client
from models.model import MessageAgentThought
class ApplicationQueueManager:
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom,
conversation_id: str,
app_mode: str,
message_id: str) -> None:
if not user_id:
raise ValueError("user is required")
self._task_id = task_id
self._user_id = user_id
self._invoke_from = invoke_from
self._conversation_id = str(conversation_id)
self._app_mode = app_mode
self._message_id = str(message_id)
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
q = queue.Queue()
self._q = q
def listen(self) -> Generator:
"""
Listen to queue
:return:
"""
# wait for 10 minutes to stop listen
listen_timeout = 600
start_time = time.time()
last_ping_time = 0
while True:
try:
message = self._q.get(timeout=1)
if message is None:
break
yield message
except queue.Empty:
continue
finally:
elapsed_time = time.time() - start_time
if elapsed_time >= listen_timeout or self._is_stopped():
# publish two messages to make sure the client can receive the stop signal
# and stop listening after the stop signal processed
self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
self.stop_listen()
if elapsed_time // 10 > last_ping_time:
self.publish(QueuePingEvent())
last_ping_time = elapsed_time // 10
def stop_listen(self) -> None:
"""
Stop listen to queue
:return:
"""
self._q.put(None)
def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
"""
Publish chunk message to channel
:param chunk: chunk
:return:
"""
self.publish(QueueMessageEvent(
chunk=chunk
))
def publish_message_replace(self, text: str) -> None:
"""
Publish message replace
:param text: text
:return:
"""
self.publish(QueueMessageReplaceEvent(
text=text
))
def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
"""
Publish retriever resources
:return:
"""
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
def publish_annotation_reply(self, message_annotation_id: str) -> None:
"""
Publish annotation reply
:param message_annotation_id: message annotation id
:return:
"""
self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
def publish_message_end(self, llm_result: LLMResult) -> None:
"""
Publish message end
:param llm_result: llm result
:return:
"""
self.publish(QueueMessageEndEvent(llm_result=llm_result))
self.stop_listen()
def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
"""
Publish agent thought
:param message_agent_thought: message agent thought
:return:
"""
self.publish(QueueAgentThoughtEvent(
agent_thought_id=message_agent_thought.id
))
def publish_error(self, e) -> None:
"""
Publish error
:param e: error
:return:
"""
self.publish(QueueErrorEvent(
error=e
))
self.stop_listen()
def publish(self, event: AppQueueEvent) -> None:
"""
Publish event to queue
:param event:
:return:
"""
self._check_for_sqlalchemy_models(event.dict())
message = QueueMessage(
task_id=self._task_id,
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event
)
self._q.put(message)
if isinstance(event, QueueStopEvent):
self.stop_listen()
@classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
"""
Set task stop flag
:return:
"""
result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
if result is None:
return
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
if result != f"{user_prefix}-{user_id}":
return
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)
def _is_stopped(self) -> bool:
"""
Check if task is stopped
:return:
"""
stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
result = redis_client.get(stopped_cache_key)
if result is not None:
redis_client.delete(stopped_cache_key)
return True
return False
@classmethod
def _generate_task_belong_cache_key(cls, task_id: str) -> str:
"""
Generate task belong cache key
:param task_id: task id
:return:
"""
return f"generate_task_belong:{task_id}"
@classmethod
def _generate_stopped_cache_key(cls, task_id: str) -> str:
"""
Generate stopped cache key
:param task_id: task id
:return:
"""
return f"generate_task_stopped:{task_id}"
def _check_for_sqlalchemy_models(self, data: Any):
# from entity to dict or list
if isinstance(data, dict):
for key, value in data.items():
self._check_for_sqlalchemy_models(value)
elif isinstance(data, list):
for item in data:
self._check_for_sqlalchemy_models(item)
else:
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed.")
class ConversationTaskStoppedException(Exception):
pass

View File

@@ -2,30 +2,40 @@ import json
import logging
import time
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Union, Optional, cast
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from models.model import MessageChain, MessageAgentThought, Message
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
def __init__(self, model_config: ModelConfigEntity,
queue_manager: ApplicationQueueManager,
message: Message,
message_chain: MessageChain) -> None:
"""Initialize callback handler."""
self.model_instance = model_instance
self.conversation_message_task = conversation_message_task
self.model_config = model_config
self.queue_manager = queue_manager
self.message = message
self.message_chain = message_chain
model_type_instance = self.model_config.provider_model_bundle.model_type_instance
self.model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
self.current_chain = None
@property
def agent_loops(self) -> List[AgentLoop]:
@@ -46,65 +56,60 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Whether to ignore chain callbacks."""
return True
def on_llm_before_invoke(self, prompt_messages: list[PromptMessage]) -> None:
if not self._current_loop:
# Agent start with a LLM query
self._current_loop = AgentLoop(
position=len(self._agent_loops) + 1,
prompt="\n".join([prompt_message.content for prompt_message in prompt_messages]),
status='llm_started',
started_at=time.perf_counter()
)
def on_llm_after_invoke(self, result: RuntimeLLMResult) -> None:
if self._current_loop and self._current_loop.status == 'llm_started':
self._current_loop.status = 'llm_end'
if result.usage:
self._current_loop.prompt_tokens = result.usage.prompt_tokens
else:
self._current_loop.prompt_tokens = self.model_type_instance.get_num_tokens(
model=self.model_config.model,
credentials=self.model_config.credentials,
prompt_messages=[UserPromptMessage(content=self._current_loop.prompt)]
)
completion_message = result.message
if completion_message.tool_calls:
self._current_loop.completion \
= json.dumps({'function_call': completion_message.tool_calls})
else:
self._current_loop.completion = completion_message.content
if result.usage:
self._current_loop.completion_tokens = result.usage.completion_tokens
else:
self._current_loop.completion_tokens = self.model_type_instance.get_num_tokens(
model=self.model_config.model,
credentials=self.model_config.credentials,
prompt_messages=[AssistantPromptMessage(content=self._current_loop.completion)]
)
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
if not self._current_loop:
# Agent start with a LLM query
self._current_loop = AgentLoop(
position=len(self._agent_loops) + 1,
prompt="\n".join([message.content for message in messages[0]]),
status='llm_started',
started_at=time.perf_counter()
)
pass
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
# serialized={'name': 'OpenAI'}
# prompts=['Answer the following questions...\nThought:']
# kwargs={}
if not self._current_loop:
# Agent start with a LLM query
self._current_loop = AgentLoop(
position=len(self._agent_loops) + 1,
prompt=prompts[0],
status='llm_started',
started_at=time.perf_counter()
)
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing."""
# kwargs={}
if self._current_loop and self._current_loop.status == 'llm_started':
self._current_loop.status = 'llm_end'
if response.llm_output:
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
else:
self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self._current_loop.prompt)]
)
completion_generation = response.generations[0][0]
if isinstance(completion_generation, ChatGeneration):
completion_message = completion_generation.message
if 'function_call' in completion_message.additional_kwargs:
self._current_loop.completion \
= json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
else:
self._current_loop.completion = response.generations[0][0].text
else:
self._current_loop.completion = completion_generation.text
if response.llm_output:
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self._current_loop.completion)]
)
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -150,10 +155,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if completion is not None:
self._current_loop.completion = completion
self._message_agent_thought = self.conversation_message_task.on_agent_start(
self.current_chain,
self._current_loop
)
self._message_agent_thought = self._init_agent_thought()
def on_tool_end(
self,
@@ -176,9 +178,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_instance, self._current_loop
)
self._complete_agent_thought(self._message_agent_thought)
self._agent_loops.append(self._current_loop)
self._current_loop = None
@@ -202,17 +202,62 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self._current_loop.thought = '[DONE]'
self._message_agent_thought = self.conversation_message_task.on_agent_start(
self.current_chain,
self._current_loop
)
self._message_agent_thought = self._init_agent_thought()
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_instance, self._current_loop
)
self._complete_agent_thought(self._message_agent_thought)
self._agent_loops.append(self._current_loop)
self._current_loop = None
self._message_agent_thought = None
elif not self._current_loop and self._agent_loops:
self._agent_loops[-1].status = 'agent_finish'
def _init_agent_thought(self) -> MessageAgentThought:
message_agent_thought = MessageAgentThought(
message_id=self.message.id,
message_chain_id=self.message_chain.id,
position=self._current_loop.position,
thought=self._current_loop.thought,
tool=self._current_loop.tool_name,
tool_input=self._current_loop.tool_input,
message=self._current_loop.prompt,
message_price_unit=0,
answer=self._current_loop.completion,
answer_price_unit=0,
created_by_role=('account' if self.message.from_source == 'console' else 'end_user'),
created_by=(self.message.from_account_id
if self.message.from_source == 'console' else self.message.from_end_user_id)
)
db.session.add(message_agent_thought)
db.session.commit()
self.queue_manager.publish_agent_thought(message_agent_thought)
return message_agent_thought
def _complete_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
loop_message_tokens = self._current_loop.prompt_tokens
loop_answer_tokens = self._current_loop.completion_tokens
# transform usage
llm_usage = self.model_type_instance._calc_response_usage(
self.model_config.model,
self.model_config.credentials,
loop_message_tokens,
loop_answer_tokens
)
message_agent_thought.observation = self._current_loop.tool_output
message_agent_thought.tool_process_data = '' # currently not support
message_agent_thought.message_token = loop_message_tokens
message_agent_thought.message_unit_price = llm_usage.prompt_unit_price
message_agent_thought.message_price_unit = llm_usage.prompt_price_unit
message_agent_thought.answer_token = loop_answer_tokens
message_agent_thought.answer_unit_price = llm_usage.completion_unit_price
message_agent_thought.answer_price_unit = llm_usage.completion_price_unit
message_agent_thought.latency = self._current_loop.latency
message_agent_thought.tokens = self._current_loop.prompt_tokens + self._current_loop.completion_tokens
message_agent_thought.total_price = llm_usage.total_price
message_agent_thought.currency = llm_usage.currency
db.session.commit()

View File

@@ -1,74 +0,0 @@
import json
import logging
from json import JSONDecodeError
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.conversation_message_task import ConversationMessageTask
class DatasetToolCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self.queries = []
self.conversation_message_task = conversation_message_task
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return True
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return True
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return False
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
tool_name: str = serialized.get('name')
dataset_id = tool_name.removeprefix('dataset-')
try:
input_dict = json.loads(input_str.replace("'", "\""))
query = input_dict.get('query')
except JSONDecodeError:
query = input_str
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
logging.debug("Dataset tool on_llm_error: %s", error)

View File

@@ -1,16 +0,0 @@
from pydantic import BaseModel
class ChainResult(BaseModel):
type: str = None
prompt: dict = None
completion: dict = None
status: str = 'chain_started'
completed: bool = False
started_at: float = None
completed_at: float = None
agent_result: dict = None
"""only when type is 'AgentExecutor'"""

View File

@@ -1,6 +0,0 @@
from pydantic import BaseModel
class DatasetQueryObj(BaseModel):
dataset_id: str = None
query: str = None

View File

@@ -1,8 +0,0 @@
from pydantic import BaseModel
class LLMMessage(BaseModel):
prompt: str = ''
prompt_tokens: int = 0
completion: str = ''
completion_tokens: int = 0

View File

@@ -1,17 +1,44 @@
from typing import List
from typing import List, Union
from langchain.schema import Document
from core.conversation_message_task import ConversationMessageTask
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from extensions.ext_database import db
from models.dataset import DocumentSegment
from models.dataset import DocumentSegment, DatasetQuery
from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool."""
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
self.conversation_message_task = conversation_message_task
def __init__(self, queue_manager: ApplicationQueueManager,
app_id: str,
message_id: str,
user_id: str,
invoke_from: InvokeFrom) -> None:
self._queue_manager = queue_manager
self._app_id = app_id
self._message_id = message_id
self._user_id = user_id
self._invoke_from = invoke_from
def on_query(self, query: str, dataset_id: str) -> None:
"""
Handle query.
"""
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source='app',
source_app_id=self._app_id,
created_by_role=('account'
if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
created_by=self._user_id
)
db.session.add(dataset_query)
db.session.commit()
def on_tool_end(self, documents: List[Document]) -> None:
"""Handle tool end."""
@@ -30,4 +57,27 @@ class DatasetIndexToolCallbackHandler:
def return_retriever_resource_info(self, resource: List):
"""Handle return_retriever_resource_info."""
self.conversation_message_task.on_dataset_query_finish(resource)
if resource and len(resource) > 0:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self._message_id,
position=item.get('position'),
dataset_id=item.get('dataset_id'),
dataset_name=item.get('dataset_name'),
document_id=item.get('document_id'),
document_name=item.get('document_name'),
data_source_type=item.get('data_source_type'),
segment_id=item.get('segment_id'),
score=item.get('score') if 'score' in item else None,
hit_count=item.get('hit_count') if 'hit_count' else None,
word_count=item.get('word_count') if 'word_count' in item else None,
segment_position=item.get('segment_position') if 'segment_position' in item else None,
index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
content=item.get('content'),
retriever_from=item.get('retriever_from'),
created_by=self._user_id
)
db.session.add(dataset_retriever_resource)
db.session.commit()
self._queue_manager.publish_retriever_resources(resource)

View File

@@ -1,284 +0,0 @@
import logging
import threading
import time
from typing import Any, Dict, List, Union, Optional
from flask import Flask, current_app
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage
from pydantic import BaseModel
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
ImagePromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.moderation.base import ModerationOutputsResult, ModerationAction
from core.moderation.factory import ModerationFactory
class ModerationRule(BaseModel):
type: str
config: Dict[str, Any]
class LLMCallbackHandler(BaseCallbackHandler):
raise_error: bool = True
def __init__(self, model_instance: BaseLLM,
conversation_message_task: ConversationMessageTask):
self.model_instance = model_instance
self.llm_message = LLMMessage()
self.start_at = None
self.conversation_message_task = conversation_message_task
self.output_moderation_handler = None
self.init_output_moderation()
def init_output_moderation(self):
app_model_config = self.conversation_message_task.app_model_config
sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
self.output_moderation_handler = OutputModerationHandler(
tenant_id=self.conversation_message_task.tenant_id,
app_id=self.conversation_message_task.app.id,
rule=ModerationRule(
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config")
),
on_message_replace_func=self.conversation_message_task.on_message_replace
)
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
real_prompts = []
for message in messages[0]:
if message.type == 'human':
role = 'user'
elif message.type == 'ai':
role = 'assistant'
else:
role = 'system'
real_prompts.append({
"role": role,
"text": message.content,
"files": [{
"type": file.type.value,
"data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
"detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
} for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
})
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
self.llm_message.prompt = [{
"role": 'user',
"text": prompts[0]
}]
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
self.llm_message.completion = self.output_moderation_handler.moderation_completion(
completion=response.generations[0][0].text,
public_event=True if self.conversation_message_task.streaming else False
)
else:
self.llm_message.completion = response.generations[0][0].text
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(self.llm_message.completion)
if response.llm_output and 'token_usage' in response.llm_output:
if 'prompt_tokens' in response.llm_output['token_usage']:
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
if 'completion_tokens' in response.llm_output['token_usage']:
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
self.conversation_message_task.save_message(self.llm_message)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
ex = ConversationTaskInterruptException()
self.on_llm_error(error=ex)
raise ex
try:
self.conversation_message_task.append_message_text(token)
self.llm_message.completion += token
if self.output_moderation_handler:
self.output_moderation_handler.append_new_token(token)
except ConversationTaskStoppedException as ex:
self.on_llm_error(error=ex)
raise ex
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
if isinstance(error, ConversationTaskStoppedException):
if self.conversation_message_task.streaming:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
if isinstance(error, ConversationTaskInterruptException):
self.llm_message.completion = self.output_moderation_handler.get_final_output()
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message)
else:
logging.debug("on_llm_error: %s", error)
class OutputModerationHandler(BaseModel):
DEFAULT_BUFFER_SIZE: int = 300
tenant_id: str
app_id: str
rule: ModerationRule
on_message_replace_func: Any
thread: Optional[threading.Thread] = None
thread_running: bool = True
buffer: str = ''
is_final_chunk: bool = False
final_output: Optional[str] = None
class Config:
arbitrary_types_allowed = True
def should_direct_output(self):
return self.final_output is not None
def get_final_output(self):
return self.final_output
def append_new_token(self, token: str):
self.buffer += token
if not self.thread:
self.thread = self.start_thread()
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
self.buffer = completion
self.is_final_chunk = True
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=completion
)
if not result or not result.flagged:
return completion
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
else:
final_output = result.text
if public_event:
self.on_message_replace_func(final_output)
return final_output
def start_thread(self) -> threading.Thread:
buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
thread = threading.Thread(target=self.worker, kwargs={
'flask_app': current_app._get_current_object(),
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
})
thread.start()
return thread
def stop_thread(self):
if self.thread and self.thread.is_alive():
self.thread_running = False
def worker(self, flask_app: Flask, buffer_size: int):
with flask_app.app_context():
current_length = 0
while self.thread_running:
moderation_buffer = self.buffer
buffer_length = len(moderation_buffer)
if not self.is_final_chunk:
chunk_length = buffer_length - current_length
if 0 <= chunk_length < buffer_size:
time.sleep(1)
continue
current_length = buffer_length
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=moderation_buffer
)
if not result or not result.flagged:
continue
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
self.final_output = final_output
else:
final_output = result.text + self.buffer[len(moderation_buffer):]
# trigger replace event
if self.thread_running:
self.on_message_replace_func(final_output)
if result.action == ModerationAction.DIRECT_OUTPUT:
break
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
try:
moderation_factory = ModerationFactory(
name=self.rule.type,
app_id=app_id,
tenant_id=tenant_id,
config=self.rule.config
)
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
logging.error("Moderation Output error: %s", e)
return None

View File

@@ -1,76 +0,0 @@
import logging
import time
from typing import Any, Dict, Union
from langchain.callbacks.base import BaseCallbackHandler
from core.callback_handler.entity.chain_result import ChainResult
from core.conversation_message_task import ConversationMessageTask
class MainChainGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self._current_chain_result = None
self._current_chain_message = None
self.conversation_message_task = conversation_message_task
self.agent_callback = None
def clear_chain_results(self) -> None:
self._current_chain_result = None
self._current_chain_message = None
if self.agent_callback:
self.agent_callback.current_chain = None
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return True
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return True
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
if not self._current_chain_result:
chain_type = serialized['id'][-1]
if chain_type:
self._current_chain_result = ChainResult(
type=chain_type,
prompt=inputs,
started_at=time.perf_counter()
)
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
if self.agent_callback:
self.agent_callback.current_chain = self._current_chain_message
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
if self._current_chain_result and self._current_chain_result.status == 'chain_started':
self._current_chain_result.status = 'chain_ended'
self._current_chain_result.completion = outputs
self._current_chain_result.completed = True
self._current_chain_result.completed_at = time.perf_counter()
self.conversation_message_task.on_chain_end(self._current_chain_message, self._current_chain_result)
self.clear_chain_results()
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.debug("Dataset tool on_chain_error: %s", error)
self.clear_chain_results()

View File

@@ -79,8 +79,11 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
"""Run on agent action."""
tool = action.tool
tool_input = action.tool_input
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
try:
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
except ValueError:
thought = ''
log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
print_text("\n[on_agent_action]\n" + log + "\n", color='green')

View File

@@ -5,15 +5,19 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import LLMResult, Generation
from langchain.schema.language_model import BaseLanguageModel
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.entities.application_entities import ModelConfigEntity
from core.model_manager import ModelInstance
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain):
model_instance: BaseLLM
model_config: ModelConfigEntity
"""The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="")
parameters: Dict[str, Any] = {}
agent_llm_callback: Optional[AgentLLMCallback] = None
def generate(
self,
@@ -23,14 +27,23 @@ class LLMChain(LCLLMChain):
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
messages = prompts[0].to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
stop=stop
prompt_messages = lc_messages_to_prompt_messages(messages)
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
stream=False,
stop=stop,
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None,
model_parameters=self.parameters
)
generations = [
[Generation(text=result.content)]
[Generation(text=result.message.content)]
]
return LLMResult(generations=generations)

View File

@@ -1,417 +0,0 @@
import concurrent
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, Union, Tuple
from flask import current_app, Flask
from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.external_data_tool.factory import ExternalDataToolFactory
from core.file.file_obj import FileObj
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory
class Completion:
@classmethod
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
auto_generate_name: bool = True):
"""
errors: ProviderTokenNotInitError
"""
query = PromptTemplateParser.remove_template_variables(query)
memory = None
if conversation:
# get memory of conversation (read-only)
memory = cls.get_memory_from_conversation(
tenant_id=app.tenant_id,
app_model_config=app_model_config,
conversation=conversation,
return_messages=False
)
inputs = conversation.inputs
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=app.tenant_id,
model_config=app_model_config.model_dict,
streaming=streaming
)
conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
app_model_config=app_model_config,
user=user,
conversation=conversation,
is_override=is_override,
inputs=inputs,
query=query,
files=files,
streaming=streaming,
model_instance=final_model_instance,
auto_generate_name=auto_generate_name
)
prompt_message_files = [file.prompt_message_file for file in files]
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
mode=app.mode,
model_instance=final_model_instance,
app_model_config=app_model_config,
query=query,
inputs=inputs,
files=prompt_message_files
)
# init orchestrator rule parser
orchestrator_rule_parser = OrchestratorRuleParser(
tenant_id=app.tenant_id,
app_model_config=app_model_config
)
try:
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
try:
# process sensitive_word_avoidance
inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
except ModerationException as e:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
files=prompt_message_files,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=str(e)
)
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_model_config.external_data_tools_list
if external_data_tools:
inputs = cls.fill_in_inputs_from_external_data_tools(
tenant_id=app.tenant_id,
app_id=app.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
conversation_message_task=conversation_message_task,
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
chain_callback=chain_callback,
tenant_id=app.tenant_id,
retriever_from=retriever_from
)
query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)
# run agent executor
agent_execute_result = None
if query_for_agent and agent_executor:
should_use_agent = agent_executor.should_use_agent(query_for_agent)
if should_use_agent:
agent_execute_result = agent_executor.run(query_for_agent)
# When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
PlanningStrategy.REACT_ROUTER]:
fake_response = agent_execute_result.output
# run the final llm
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
files=prompt_message_files,
agent_execute_result=agent_execute_result,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=fake_response
)
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
return
except ChunkedEncodingError as e:
# Interrupt by LLM (like OpenAI), handle it.
logging.warning(f'ChunkedEncodingError: {e}')
conversation_message_task.end()
return
@classmethod
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
return inputs, query
type = app_model_config.sensitive_word_avoidance_dict['type']
moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
moderation_result = moderation.moderation_for_inputs(inputs, query)
if not moderation_result.flagged:
return inputs, query
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
raise ModerationException(moderation_result.preset_response)
elif moderation_result.action == ModerationAction.OVERRIDED:
inputs = moderation_result.inputs
query = moderation_result.query
return inputs, query
@classmethod
def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
inputs: dict, query: str) -> dict:
"""
Fill in variable inputs from external data tools if exists.
:param tenant_id: workspace id
:param app_id: app id
:param external_data_tools: external data tools configs
:param inputs: the inputs
:param query: the query
:return: the filled inputs
"""
# Group tools by type and config
grouped_tools = {}
for tool in external_data_tools:
if not tool.get("enabled"):
continue
tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
grouped_tools.setdefault(tool_key, []).append(tool)
results = {}
with ThreadPoolExecutor() as executor:
futures = {}
for tool in external_data_tools:
if not tool.get("enabled"):
continue
future = executor.submit(
cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool,
inputs, query
)
futures[future] = tool
for future in concurrent.futures.as_completed(futures):
tool_variable, result = future.result()
results[tool_variable] = result
inputs.update(results)
return inputs
@classmethod
def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
with flask_app.app_context():
tool_variable = external_data_tool.get("variable")
tool_type = external_data_tool.get("type")
tool_config = external_data_tool.get("config")
external_data_tool_factory = ExternalDataToolFactory(
name=tool_type,
tenant_id=tenant_id,
app_id=app_id,
variable=tool_variable,
config=tool_config
)
# query external data tool
result = external_data_tool_factory.query(
inputs=inputs,
query=query
)
return tool_variable, result
@classmethod
def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
if app.mode != 'completion':
return query
return inputs.get(app_model_config.dataset_query_variable, "")
@classmethod
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
inputs: dict,
files: List[PromptMessageFile],
agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
prompt_transform = PromptTransform()
# get llm prompt
if app_model_config.prompt_type == 'simple':
prompt_messages, stop_words = prompt_transform.get_prompt(
app_mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
files=files,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory,
model_instance=model_instance
)
else:
prompt_messages = prompt_transform.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
files=files,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory,
model_instance=model_instance
)
model_config = app_model_config.model_dict
completion_params = model_config.get("completion_params", {})
stop_words = completion_params.get("stop", [])
cls.recale_llm_max_tokens(
model_instance=model_instance,
prompt_messages=prompt_messages,
)
response = model_instance.run(
messages=prompt_messages,
stop=stop_words if stop_words else None,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
return response
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
max_token_limit: int) -> str:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
return external_context[memory_key]
@classmethod
def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
conversation: Conversation,
**kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
# only for calc token in memory
memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=tenant_id,
model_config=app_model_config.model_dict
)
# use llm config from conversation
memory = ReadOnlyConversationTokenDBBufferSharedMemory(
conversation=conversation,
model_instance=memory_model_instance,
max_token_limit=kwargs.get("max_token_limit", 2048),
memory_key=kwargs.get("memory_key", "chat_history"),
return_messages=kwargs.get("return_messages", True),
input_key=kwargs.get("input_key", "input"),
output_key=kwargs.get("output_key", "output"),
message_limit=kwargs.get("message_limit", 10),
)
return memory
@classmethod
def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
model_limited_tokens = model_instance.model_rules.max_tokens.max
max_tokens = model_instance.get_model_kwargs().max_tokens
if model_limited_tokens is None:
return -1
if max_tokens is None:
max_tokens = 0
prompt_transform = PromptTransform()
# get prompt without memory and context
if app_model_config.prompt_type == 'simple':
prompt_messages, _ = prompt_transform.get_prompt(
app_mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
files=files,
context=None,
memory=None,
model_instance=model_instance
)
else:
prompt_messages = prompt_transform.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
files=files,
context=None,
memory=None,
model_instance=model_instance
)
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size.")
return rest_tokens
@classmethod
def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_limited_tokens = model_instance.model_rules.max_tokens.max
max_tokens = model_instance.get_model_kwargs().max_tokens
if model_limited_tokens is None:
return
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_limited_tokens:
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
# update model instance max tokens
model_kwargs = model_instance.get_model_kwargs()
model_kwargs.max_tokens = max_tokens
model_instance.set_model_kwargs(model_kwargs)

View File

@@ -1,489 +0,0 @@
import json
import time
from typing import Optional, Union, List
from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
from core.file.file_obj import FileObj
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
MessageChain, DatasetRetrieverResource, MessageFile
class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, files: List[FileObj], streaming: bool,
model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
auto_generate_name: bool = True):
self.start_at = time.perf_counter()
self.task_id = task_id
self.app = app
self.tenant_id = app.tenant_id
self.app_model_config = app_model_config
self.is_override = is_override
self.user = user
self.inputs = inputs
self.query = query
self.files = files
self.streaming = streaming
self.conversation = conversation
self.is_new_conversation = False
self.model_instance = model_instance
self.message = None
self.retriever_resource = None
self.auto_generate_name = auto_generate_name
self.model_dict = self.app_model_config.model_dict
self.provider_name = self.model_dict.get('provider')
self.model_name = self.model_dict.get('name')
self.mode = app.mode
self.init()
self._pub_handler = PubHandler(
user=self.user,
task_id=self.task_id,
message=self.message,
conversation=self.conversation,
chain_pub=False, # disabled currently
agent_thought_pub=True
)
def init(self):
override_model_configs = None
if self.is_override:
override_model_configs = self.app_model_config.to_dict()
introduction = ''
system_instruction = ''
system_instruction_tokens = 0
if self.mode == 'chat':
introduction = self.app_model_config.opening_statement
if introduction:
prompt_template = PromptTemplateParser(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
try:
introduction = prompt_template.format(prompt_inputs)
except KeyError:
pass
if self.app_model_config.pre_prompt:
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
system_instruction = system_message.content
model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_provider_name=self.provider_name,
model_name=self.model_name
)
system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
if not self.conversation:
self.is_new_conversation = True
self.conversation = Conversation(
app_id=self.app.id,
app_model_config_id=self.app_model_config.id,
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=self.mode,
name='New conversation',
inputs=self.inputs,
introduction=introduction,
system_instruction=system_instruction,
system_instruction_tokens=system_instruction_tokens,
status='normal',
from_source=('console' if isinstance(self.user, Account) else 'api'),
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
from_account_id=(self.user.id if isinstance(self.user, Account) else None),
)
db.session.add(self.conversation)
db.session.commit()
self.message = Message(
app_id=self.app.id,
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=self.conversation.id,
inputs=self.inputs,
query=self.query,
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency=self.model_instance.get_currency(),
from_source=('console' if isinstance(self.user, Account) else 'api'),
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
from_account_id=(self.user.id if isinstance(self.user, Account) else None),
agent_based=self.app_model_config.agent_mode_dict.get('enabled'),
)
db.session.add(self.message)
db.session.commit()
for file in self.files:
message_file = MessageFile(
message_id=self.message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
url=file.url,
upload_file_id=file.upload_file_id,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_file)
db.session.commit()
def append_message_text(self, text: str):
if text is not None:
self._pub_handler.pub_text(text)
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price
self.message.message = llm_message.prompt
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
self.message.answer = PromptTemplateParser.remove_template_variables(
llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
self.message.answer_price_unit = answer_price_unit
self.message.provider_response_latency = time.perf_counter() - self.start_at
self.message.total_price = total_price
db.session.commit()
message_was_created.send(
self.message,
conversation=self.conversation,
is_first_message=self.is_new_conversation,
auto_generate_name=self.auto_generate_name
)
if not by_stopped:
self.end()
def init_chain(self, chain_result: ChainResult):
message_chain = MessageChain(
message_id=self.message.id,
type=chain_result.type,
input=json.dumps(chain_result.prompt),
output=''
)
db.session.add(message_chain)
db.session.commit()
return message_chain
def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
message_chain.output = json.dumps(chain_result.completion)
db.session.commit()
self._pub_handler.pub_chain(message_chain)
def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
message_agent_thought = MessageAgentThought(
message_id=self.message.id,
message_chain_id=message_chain.id,
position=agent_loop.position,
thought=agent_loop.thought,
tool=agent_loop.tool_name,
tool_input=agent_loop.tool_input,
message=agent_loop.prompt,
message_price_unit=0,
answer=agent_loop.completion,
answer_price_unit=0,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_agent_thought)
db.session.commit()
self._pub_handler.pub_agent_thought(message_agent_thought)
return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price
message_agent_thought.observation = agent_loop.tool_output
message_agent_thought.tool_process_data = '' # currently not support
message_agent_thought.message_token = loop_message_tokens
message_agent_thought.message_unit_price = agent_message_unit_price
message_agent_thought.message_price_unit = agent_message_price_unit
message_agent_thought.answer_token = loop_answer_tokens
message_agent_thought.answer_unit_price = agent_answer_unit_price
message_agent_thought.answer_price_unit = agent_answer_price_unit
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
message_agent_thought.currency = agent_model_instance.get_currency()
db.session.commit()
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
dataset_query = DatasetQuery(
dataset_id=dataset_query_obj.dataset_id,
content=dataset_query_obj.query,
source='app',
source_app_id=self.app.id,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(dataset_query)
db.session.commit()
def on_dataset_query_finish(self, resource: List):
if resource and len(resource) > 0:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self.message.id,
position=item.get('position'),
dataset_id=item.get('dataset_id'),
dataset_name=item.get('dataset_name'),
document_id=item.get('document_id'),
document_name=item.get('document_name'),
data_source_type=item.get('data_source_type'),
segment_id=item.get('segment_id'),
score=item.get('score') if 'score' in item else None,
hit_count=item.get('hit_count') if 'hit_count' else None,
word_count=item.get('word_count') if 'word_count' in item else None,
segment_position=item.get('segment_position') if 'segment_position' in item else None,
index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
content=item.get('content'),
retriever_from=item.get('retriever_from'),
created_by=self.user.id
)
db.session.add(dataset_retriever_resource)
db.session.commit()
self.retriever_resource = resource
def on_message_replace(self, text: str):
if text is not None:
self._pub_handler.pub_message_replace(text)
def message_end(self):
self._pub_handler.pub_message_end(self.retriever_resource)
def end(self):
self._pub_handler.pub_message_end(self.retriever_resource)
self._pub_handler.pub_end()
class PubHandler:
def __init__(self, user: Union[Account, EndUser], task_id: str,
message: Message, conversation: Conversation,
chain_pub: bool = False, agent_thought_pub: bool = False):
self._channel = PubHandler.generate_channel_name(user, task_id)
self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id)
self._task_id = task_id
self._message = message
self._conversation = conversation
self._chain_pub = chain_pub
self._agent_thought_pub = agent_thought_pub
@classmethod
def generate_channel_name(cls, user: Union[Account, EndUser], task_id: str):
if not user:
raise ValueError("user is required")
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
return "generate_result:{}-{}".format(user_str, task_id)
@classmethod
def generate_stopped_cache_key(cls, user: Union[Account, EndUser], task_id: str):
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
return "generate_result_stopped:{}-{}".format(user_str, task_id)
def pub_text(self, text: str):
content = {
'event': 'message',
'data': {
'task_id': self._task_id,
'message_id': str(self._message.id),
'text': text,
'mode': self._conversation.mode,
'conversation_id': str(self._conversation.id)
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_message_replace(self, text: str):
content = {
'event': 'message_replace',
'data': {
'task_id': self._task_id,
'message_id': str(self._message.id),
'text': text,
'mode': self._conversation.mode,
'conversation_id': str(self._conversation.id)
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_chain(self, message_chain: MessageChain):
if self._chain_pub:
content = {
'event': 'chain',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'chain_id': message_chain.id,
'type': message_chain.type,
'input': json.loads(message_chain.input),
'output': json.loads(message_chain.output),
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_agent_thought(self, message_agent_thought: MessageAgentThought):
if self._agent_thought_pub:
content = {
'event': 'agent_thought',
'data': {
'id': message_agent_thought.id,
'task_id': self._task_id,
'message_id': self._message.id,
'chain_id': message_agent_thought.message_chain_id,
'position': message_agent_thought.position,
'thought': message_agent_thought.thought,
'tool': message_agent_thought.tool,
'tool_input': message_agent_thought.tool_input,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_message_end(self, retriever_resource: List):
content = {
'event': 'message_end',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
}
}
if retriever_resource:
content['data']['retriever_resources'] = retriever_resource
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_end(self):
content = {
'event': 'end',
}
redis_client.publish(self._channel, json.dumps(content))
@classmethod
def pub_error(cls, user: Union[Account, EndUser], task_id: str, e):
content = {
'error': type(e).__name__,
'description': e.description if getattr(e, 'description', None) is not None else str(e)
}
channel = cls.generate_channel_name(user, task_id)
redis_client.publish(channel, json.dumps(content))
def _is_stopped(self):
return redis_client.get(self._stopped_cache_key) is not None
@classmethod
def ping(cls, user: Union[Account, EndUser], task_id: str):
content = {
'event': 'ping'
}
channel = cls.generate_channel_name(user, task_id)
redis_client.publish(channel, json.dumps(content))
@classmethod
def stop(cls, user: Union[Account, EndUser], task_id: str):
stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
redis_client.setex(stopped_cache_key, 600, 1)
class ConversationTaskStoppedException(Exception):
pass
class ConversationTaskInterruptException(Exception):
pass

View File

@@ -3,7 +3,8 @@ from pathlib import Path
from typing import List, Union, Optional
import requests
from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader
from flask import current_app
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document
from core.data_loader.loader.csv_loader import CSVLoader
@@ -11,6 +12,13 @@ from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
from extensions.ext_storage import storage
from models.model import UploadFile
@@ -49,14 +57,34 @@ class FileExtractor:
input_file = Path(file_path)
delimiter = '\n'
file_extension = input_file.suffix.lower()
if is_automatic:
loader = UnstructuredFileLoader(
file_path, strategy="hi_res", mode="elements"
)
# loader = UnstructuredAPIFileLoader(
# file_path=filenames[0],
# api_key="FAKE_API_KEY",
# )
etl_type = current_app.config['ETL_TYPE']
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
if etl_type == 'Unstructured':
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension == '.docx':
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
elif file_extension == '.eml':
loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
elif file_extension == '.ppt':
loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
elif file_extension == '.pptx':
loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
elif file_extension == '.xml':
loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
else:
# txt
loader = UnstructuredTextLoader(file_path, unstructured_api_url)
else:
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)

View File

@@ -0,0 +1,50 @@
import logging
import base64
from typing import List
from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredEmailLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str,
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.email import partition_email
elements = partition_email(filename=self._file_path, api_url=self._api_url)
# noinspection PyBroadException
try:
for element in elements:
element_text = element.text.strip()
padding_needed = 4 - len(element_text) % 4
element_text += '=' * padding_needed
element_decode = base64.b64decode(element_text)
soup = BeautifulSoup(element_decode.decode('utf-8'), 'html.parser')
element.text = soup.get_text()
except Exception:
pass
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@@ -0,0 +1,48 @@
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredMarkdownLoader(BaseLoader):
"""Load md files.
Args:
file_path: Path to the file to load.
remove_hyperlinks: Whether to remove hyperlinks from the text.
remove_images: Whether to remove images from the text.
encoding: File encoding to use. If `None`, the file will be loaded
with the default system encoding.
autodetect_encoding: Whether to try to autodetect the file encoding
if the specified encoding fails.
"""
def __init__(
self,
file_path: str,
api_url: str,
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.md import partition_md
elements = partition_md(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredMsgLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.msg import partition_msg
elements = partition_msg(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@@ -0,0 +1,47 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class UnstructuredPPTLoader(BaseLoader):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
from unstructured.partition.ppt import partition_ppt
elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
text_by_page = {}
for element in elements:
page = element.metadata.page_number
text = element.text
if page in text_by_page:
text_by_page[page] += "\n" + text
else:
text_by_page[page] = text
combined_texts = list(text_by_page.values())
documents = []
for combined_text in combined_texts:
text = combined_text.strip()
documents.append(Document(page_content=text))
return documents

Some files were not shown because too many files have changed in this diff Show More