Mirror of https://github.com/langgenius/dify.git (synced 2026-01-07 06:48:28 +00:00)

Compare commits: 73 commits
- d70d61b1cb
- d069c668f8
- e91dd28a76
- 595e9b25ba
- 06d2d8cea3
- 936c3cc4d7
- 08abbb8dba
- 972cf3cd01
- da4847c5a8
- 08494058e9
- 9080ece3fb
- 438912700c
- 6b57e4e0ff
- 6da3a33e6c
- 2c8badfea9
- 91182a86bf
- 0b7e0cadc0
- 163515c6e9
- 40d612ffc7
- 88a73ecdea
- 64642fabc4
- 7083a05a25
- 9f3ed32d0f
- 1521ac5563
- 695246d80a
- 908164f6d5
- 96206b6108
- e2fff7fd87
- 53690bfad2
- ae37a7d998
- 7b37e05dec
- 022450768f
- 7c5661152e
- fb55b3a89a
- df1509983c
- 185c2f86cd
- 10fc44e2af
- c3275dfd36
- 43741ad5d1
- 8dec406161
- 58f8d74591
- 867fc61b12
- 8e2e477a7f
- 9b34f5a9ff
- 5e34f938c1
- 2fd56cb01c
- 4f0e272549
- 1a5279a3ef
- 7775f5785f
- 2de73991ff
- 354d033e60
- ebc2cdad2e
- 5bb841935e
- 65fd4b39ce
- 96d2de2258
- a71f2863ac
- a9b942981d
- 4b1ba2ec21
- c09184fd94
- b0d8d196e1
- 7c43123956
- eede84eb9e
- b5b20234e9
- 5beb298e47
- 6b499b9a16
- 4c639961f5
- dfd3f507fb
- d5695b3170
- 994fceece3
- 8c451eb0e6
- 79b4366203
- 3675d2eae8
- 38b55d2186
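
The listing above gives only abbreviated commit IDs. A minimal sketch of how to expand them from a local clone (the two range endpoints are simply the last and first IDs in the list above, and the old..new ordering is an assumption):

```bash
# Clone the mirrored repository and list the commits in the compared range.
# 38b55d2186 and d70d61b1cb are the last and first SHAs from the list above.
git clone https://github.com/langgenius/dify.git
cd dify
git log --oneline 38b55d2186..d70d61b1cb

# Show the full author, date, message, and patch for any single commit.
git show d70d61b1cb
```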

.github/ISSUE_TEMPLATE/bug_report.yml (vendored), 13 changed lines
@@ -1,11 +1,18 @@
name: "🕷️ Bug report"
description: Report errors or unexpected behavior [please use English :)]
description: Report errors or unexpected behavior
labels:
- bug
body:
- type: markdown
- type: checkboxes
attributes:
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true

- type: input
attributes:
label: Dify version

.github/ISSUE_TEMPLATE/document_issue.yml (vendored), 10 changed lines
@@ -1,8 +1,16 @@
name: "📚 Documentation Issue"
description: Report issues in our documentation [please use English :)]
description: Report issues in our documentation
labels:
- ducumentation
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: textarea
attributes:
label: Provide a description of requested docs changes

.github/ISSUE_TEMPLATE/feature_request.yml (vendored), 11 changed lines
@@ -1,8 +1,17 @@
name: "⭐ Feature or enhancement request"
description: Propose something new. [please use English :)]
description: Propose something new.
labels:
- enhancement
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: textarea
attributes:
label: Description of the new feature / enhancement

.github/ISSUE_TEMPLATE/help_wanted.yml (vendored), 9 changed lines
@@ -3,6 +3,15 @@ description: "Request help from the community" [please use English :)]
labels:
- help-wanted
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: textarea
attributes:
label: Provide a description of the help you need

.github/ISSUE_TEMPLATE/translation_issue.yml (vendored), 10 changed lines
@@ -3,9 +3,15 @@ description: Report incorrect translations. [please use English :)]
labels:
- translation
body:
- type: markdown
- type: checkboxes
attributes:
value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- type: input
attributes:
label: Dify version

.github/workflows/api-model-runtime-tests.yml (vendored, new file), 58 lines
@@ -0,0 +1,58 @@
name: Run Pytest

on:
pull_request:
branches:
- main
push:
branches:
- deploy/dev
- feat/model-runtime

jobs:
test:
runs-on: ubuntu-latest

env:
OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
AZURE_OPENAI_API_BASE: https://difyai-openai.openai.azure.com
AZURE_OPENAI_API_KEY: xxxxb1707exxxxxxxxxxaaxxxxxf94
ANTHROPIC_API_KEY: sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
CHATGLM_API_BASE: http://a.abc.com:11451
XINFERENCE_SERVER_URL: http://a.abc.com:11451
XINFERENCE_GENERATION_MODEL_UID: generate
XINFERENCE_CHAT_MODEL_UID: chat
XINFERENCE_EMBEDDINGS_MODEL_UID: embedding
XINFERENCE_RERANK_MODEL_UID: rerank
GOOGLE_API_KEY: abcdefghijklmnopqrstuvwxyz
HUGGINGFACE_API_KEY: hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL: a
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL: b
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
MOCK_SWITCH: true

steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.10'

- name: Cache pip dependencies
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
restore-keys: ${{ runner.os }}-pip-

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install -r api/requirements.txt

- name: Run pytest
run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py

.github/workflows/api-unit-tests.yml (vendored, file removed), 38 lines
@@ -1,38 +0,0 @@
name: Run Pytest

on:
pull_request:
branches:
- main
push:
branches:
- deploy/dev

jobs:
test:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.10'

- name: Cache pip dependencies
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
restore-keys: ${{ runner.os }}-pip-

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install -r api/requirements.txt

- name: Run pytest
run: pytest api/tests/unit_tests

.github/workflows/build-api-image.yml (vendored), 8 changed lines
@@ -14,10 +14,10 @@ jobs:
if: github.event.pull_request.draft == false
steps:
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
uses: docker/setup-qemu-action@v3

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@v2
@@ -27,7 +27,7 @@ jobs:

- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
uses: docker/metadata-action@v5
with:
images: langgenius/dify-api
tags: |
@@ -37,7 +37,7 @@ jobs:
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}

- name: Build and push
uses: docker/build-push-action@v4
uses: docker/build-push-action@v5
with:
context: "{{defaultContext}}:api"
platforms: ${{ startsWith(github.ref, 'refs/tags/') && 'linux/amd64,linux/arm64' || 'linux/amd64' }}

.github/workflows/build-web-image.yml (vendored), 8 changed lines
@@ -14,10 +14,10 @@ jobs:
if: github.event.pull_request.draft == false
steps:
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
uses: docker/setup-qemu-action@v3

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@v2
@@ -27,7 +27,7 @@ jobs:

- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
uses: docker/metadata-action@v5
with:
images: langgenius/dify-web
tags: |
@@ -37,7 +37,7 @@ jobs:
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}

- name: Build and push
uses: docker/build-push-action@v4
uses: docker/build-push-action@v5
with:
context: "{{defaultContext}}:web"
platforms: ${{ startsWith(github.ref, 'refs/tags/') && 'linux/amd64,linux/arm64' || 'linux/amd64' }}

@@ -55,6 +55,11 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r

Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!

### Provider Integrations
If you see a model provider not yet supported by Dify that you'd like to use, follow these [steps](api/core/model_runtime/README.md) to submit a PR.

### i18n (Internationalization) Support

We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.

README.md, 15 changed lines
@@ -4,7 +4,8 @@
<a href="./README_CN.md">简体中文</a> |
<a href="./README_JA.md">日本語</a> |
<a href="./README_ES.md">Español</a> |
<a href="./README_KL.md">Klingon</a>
<a href="./README_KL.md">Klingon</a> |
<a href="./README_FR.md">Français</a>
</p>

<p align="center">
@@ -20,19 +21,18 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

[v0.3.31:Surpassing the Assistants API – Dify's RAG Demonstrates an Impressive 20% Improvement.](https://dify.ai/blog/dify-ai-rag-technology-upgrade-performance-improvement-qa-accuracy)

**Dify** is an LLM application development platform that has already seen over **100,000** applications built on Dify.AI. It integrates the concepts of Backend as a Service and LLMOps, covering the core tech stack required for building generative AI-native applications, including a built-in RAG engine. With Dify, **you can self-deploy capabilities similar to Assistants API and GPTs based on any LLMs.**
**Dify** is an LLM application development platform that has helped built over **100,000** applications. It integrates BaaS and LLMOps, covering the essential tech stack for building generative AI-native applications, including a built-in RAG engine. Dify allows you to **deploy your own version of Assistants API and GPTs, based on any LLMs.**



## Use Cloud Services

Using [Dify.AI Cloud](https://dify.ai) provides all the capabilities of the open-source version, and includes a complimentary 200 GPT trial credits.
[Dify.AI Cloud](https://dify.ai) provides all the capabilities of the open-source version, and includes 200 free requests to OpenAI GPT-3.5.

## Why Dify

Dify features model neutrality and is a complete, engineered tech stack compared to hardcoded development libraries like LangChain. Unlike OpenAI's Assistants API, Dify allows for full local deployment of services.
Dify is model-agnostic and boasts a comprehensive tech stack compared to hardcoded development libraries like LangChain. Unlike OpenAI's Assistants API, Dify allows for full local deployment of services.

| Feature | Dify.AI | Assistants API | LangChain |
|---------|---------|----------------|-----------|
@@ -59,6 +59,10 @@ Dify features model neutrality and is a complete, engineered tech stack compared

## Before You Start

**Star us, and you'll get instant notifications for all new releases on GitHub!**



- [Website](https://dify.ai)
- [Docs](https://docs.dify.ai)
- [Deployment Docs](https://docs.dify.ai/getting-started/install-self-hosted)
@@ -104,6 +108,7 @@ If you need to customize the configuration, please refer to the comments in our

We welcome you to contribute to Dify to help make Dify better in various ways, submitting code, issues, new ideas, or sharing the interesting and useful AI applications you have created based on Dify. At the same time, we also welcome you to share Dify at different events, conferences, and social media.

- [Roadmap and Feedback](https://feedback.dify.ai/). Best for: sharing feedback and checking out our feature roadmap.
- [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs and errors you encounter using Dify.AI, see the [Contribution Guide](CONTRIBUTING.md).
- [Email Support](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify). Best for: questions you have about using Dify.AI.
- [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.

README_CN.md, 13 changed lines
@@ -4,7 +4,8 @@
<a href="./README_CN.md">简体中文</a> |
<a href="./README_JA.md">日本語</a> |
<a href="./README_ES.md">Español</a> |
<a href="./README_KL.md">Klingon</a>
<a href="./README_KL.md">Klingon</a> |
<a href="./README_FR.md">Français</a>
</p>

<p align="center">
@@ -24,6 +25,10 @@ Dify 是一个 LLM 应用开发平台,已经有超过 10 万个应用基于 Di



## 使用云端服务

使用 [Dify.AI Cloud](https://dify.ai) 提供开源版本的所有功能,并包含 200 次 GPT 试用额度。

## 为什么选择 Dify

Dify 具有模型中立性,相较 LangChain 等硬编码开发库 Dify 是一个完整的、工程化的技术栈,而相较于 OpenAI 的 Assistants API 你可以完全将服务部署在本地。
@@ -54,6 +59,10 @@ Dify 具有模型中立性,相较 LangChain 等硬编码开发库 Dify 是一

## 在开始之前

**关注我们,您将立即收到 GitHub 上所有新发布版本的通知!**



- [网站](https://dify.ai)
- [文档](https://docs.dify.ai)
- [部署文档](https://docs.dify.ai/getting-started/install-self-hosted)
@@ -111,4 +120,4 @@ docker compose up -d

## License

本仓库遵循 [Dify Open Source License](LICENSE) 开源协议。
本仓库遵循 [Dify Open Source License](LICENSE) 开源协议,该许可证本质上是 Apache 2.0,但有一些额外的限制。

README_ES.md, 10 changed lines
@@ -3,7 +3,9 @@
<a href="./README.md">English</a> |
<a href="./README_CN.md">简体中文</a> |
<a href="./README_JA.md">日本語</a> |
<a href="./README_ES.md">Español</a>
<a href="./README_ES.md">Español</a> |
<a href="./README_KL.md">Klingon</a> |
<a href="./README_FR.md">Français</a>
</p>

<p align="center">
@@ -56,6 +58,10 @@ Dify se caracteriza por su neutralidad de modelo y es un conjunto tecnológico c

## Antes de Empezar

**¡Danos una estrella, y recibirás notificaciones instantáneas de todos los nuevos lanzamientos en GitHub!**



- [Sitio web](https://dify.ai)
- [Documentación](https://docs.dify.ai)
- [Documentación de Implementación](https://docs.dify.ai/getting-started/install-self-hosted)
@@ -109,4 +115,4 @@ Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En

## Licencia

Este repositorio está disponible bajo la [Licencia de código abierto de Dify](LICENSE).
Este repositorio está disponible bajo la [Licencia de Código Abierto Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales.

README_FR.md (new file), 120 lines
@@ -0,0 +1,120 @@
[](https://dify.ai)
<p align="center">
<a href="./README.md">English</a> |
<a href="./README_CN.md">简体中文</a> |
<a href="./README_JA.md">日本語</a> |
<a href="./README_ES.md">Español</a> |
<a href="./README_KL.md">Klingon</a> |
<a href="./README_FR.md">Français</a>
</p>

<p align="center">
<a href="https://dify.ai" target="_blank">
<img alt="Static Badge" src="https://img.shields.io/badge/AI-Dify?logo=AI&logoColor=%20%23f5f5f5&label=Dify&labelColor=%20%23155EEF&color=%23EAECF0"></a>
<a href="https://discord.gg/FngNHpbcY7" target="_blank">
<img src="https://img.shields.io/discord/1082486657678311454?logo=discord"
alt="chat on Discord"></a>
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?style=social&logo=X"
alt="follow on Twitter"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

**Dify** est une plateforme de développement d'applications LLM qui a déjà vu plus de **100,000** applications construites sur Dify.AI. Elle intègre les concepts de Backend as a Service et LLMOps, couvrant la pile technologique de base requise pour construire des applications natives d'IA générative, y compris un moteur RAG intégré. Avec Dify, **vous pouvez auto-déployer des capacités similaires aux API Assistants et GPT basées sur n'importe quels LLM.**



## Utiliser les services cloud

L'utilisation de [Dify.AI Cloud](https://dify.ai) fournit toutes les capacités de la version open source, et comprend un essai gratuit de 200 crédits GPT.

## Pourquoi Dify

Dify présente une neutralité de modèle et est une pile technologique complète et conçue par rapport à des bibliothèques de développement codées en dur comme LangChain. Contrairement à l'API Assistants d'OpenAI, Dify permet un déploiement local complet des services.

| Fonctionnalité | Dify.AI | API Assistants | LangChain |
|---------------|----------|-----------------|------------|
| **Approche de programmation** | Orientée API | Orientée API | Orientée code Python |
| **Stratégie écosystème** | Open source | Fermé et commercial | Open source |
| **Moteur RAG** | Pris en charge | Pris en charge | Non pris en charge |
| **IDE d'invite** | Inclus | Inclus | Aucun |
| **LLM pris en charge** | Grande variété | Seulement GPT | Grande variété |
| **Déploiement local** | Pris en charge | Non pris en charge | Non applicable |

## Fonctionnalités



**1\. Support LLM**: Intégration avec la famille de modèles GPT d'OpenAI, ou les modèles de la famille open source Llama2. En fait, Dify prend en charge les modèles commerciaux grand public et les modèles open source (déployés localement ou basés sur MaaS).

**2\. IDE d'invite**: Orchestration visuelle d'applications et de services basés sur LLMs avec votre équipe.

**3\. Moteur RAG**: Comprend diverses capacités RAG basées sur l'indexation de texte intégral ou les embeddings de base de données vectorielles, permettant le chargement direct de PDF, TXT et autres formats de texte.

**4\. Agents**: Un framework d'agents basé sur l'appel de fonctions qui permet aux utilisateurs de configurer ce qu'ils voient est ce qu'ils obtiennent. Dify comprend des capacités de plug-in de base comme Google Search.

**5\. Opérations continues**: Surveillez et analysez les journaux et les performances des applications, améliorez en continu les invites, les datasets ou les modèles à l'aide de données de production.

## Avant de commencer

**Étoilez-nous, et vous recevrez des notifications instantanées pour toutes les nouvelles sorties sur GitHub !**


- [Site web](https://dify.ai)
- [Documentation](https://docs.dify.ai)
- [Documentation de déploiement](https://docs.dify.ai/getting-started/install-self-hosted)
- [FAQ](https://docs.dify.ai/getting-started/faq)

## Installer la version Communauté

### Configuration système

Avant d'installer Dify, assurez-vous que votre machine répond aux exigences minimales suivantes:

- CPU >= 2 cœurs
- RAM >= 4 Go

### Démarrage rapide

La façon la plus simple de démarrer le serveur Dify est d'exécuter notre fichier [docker-compose.yml](docker/docker-compose.yaml). Avant d'exécuter la commande d'installation, assurez-vous que [Docker](https://docs.docker.com/get-docker/) et [Docker Compose](https://docs.docker.com/compose/install/) sont installés sur votre machine:

```bash
cd docker
docker compose up -d
```

Après l'exécution, vous pouvez accéder au tableau de bord Dify dans votre navigateur à l'adresse [http://localhost/install](http://localhost/install) et démarrer le processus d'installation initiale.

### Chart Helm

Un grand merci à @BorisPolonsky pour nous avoir fourni une version [Helm Chart](https://helm.sh/) qui permet le déploiement de Dify sur Kubernetes.
Vous pouvez accéder à https://github.com/BorisPolonsky/dify-helm pour des informations de déploiement.

### Configuration

Si vous avez besoin de personnaliser la configuration, veuillez vous référer aux commentaires de notre fichier [docker-compose.yml](docker/docker-compose.yaml) et définir manuellement la configuration de l'environnement. Après avoir apporté les modifications, veuillez exécuter à nouveau `docker-compose up -d`. Vous trouverez la liste complète des variables d'environnement dans notre [documentation](https://docs.dify.ai/getting-started/install-self-hosted/environments).

## Historique d'étoiles

[](https://star-history.com/#langgenius/dify&Date)


## Communauté & Support

Nous vous invitons à contribuer à Dify pour aider à améliorer Dify de diverses manières, en soumettant du code, des problèmes, de nouvelles idées ou en partageant les applications d'IA intéressantes et utiles que vous avez créées sur la base de Dify. En même temps, nous vous invitons également à partager Dify lors de différents événements, conférences et réseaux sociaux.

- [Problèmes GitHub](https://github.com/langgenius/dify/issues). Idéal pour : les bogues et les erreurs que vous rencontrez en utilisant Dify.AI, voir le [Guide de contribution](CONTRIBUTING.md).
- [Support par courriel](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify). Idéal pour : les questions que vous avez au sujet de l'utilisation de Dify.AI.
- [Discord](https://discord.gg/FngNHpbcY7). Idéal pour : partager vos applications et discuter avec la communauté.
- [Twitter](https://twitter.com/dify_ai). Idéal pour : partager vos applications et discuter avec la communauté.
- [Licence commerciale](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry). Idéal pour : les demandes commerciales de licence de Dify.AI pour un usage commercial.

## Divulgation de la sécurité

Pour protéger votre vie privée, veuillez éviter de publier des problèmes de sécurité sur GitHub. Envoyez plutôt vos questions à security@dify.ai et nous vous fournirons une réponse plus détaillée.

## Licence

Ce référentiel est disponible sous la [Licence open source Dify](LICENSE), qui est essentiellement Apache 2.0 avec quelques restrictions supplémentaires.

README_JA.md, 24 changed lines
@@ -4,7 +4,8 @@
<a href="./README_CN.md">简体中文</a> |
<a href="./README_JA.md">日本語</a> |
<a href="./README_ES.md">Español</a> |
<a href="./README_KL.md">Klingon</a>
<a href="./README_KL.md">Klingon</a> |
<a href="./README_FR.md">Français</a>
</p>

<p align="center">
@@ -24,6 +25,8 @@

Please note that translating complex technical terms can sometimes result in slight variations in meaning due to differences in language nuances.



## クラウドサービスの利用

[Dify.AI Cloud](https://dify.ai) を使用すると、オープンソース版の全機能を利用でき、さらに200GPTのトライアルクレジットが無料で提供されます。
@@ -41,9 +44,26 @@ Difyはモデルニュートラルであり、LangChainのようなハードコ
| **サポートされるLLMs** | 豊富な種類 | GPTのみ | 豊富な種類 |
| **ローカルデプロイメント** | サポート済み | 非サポート | 該当なし |

## 機能



**1\. LLMサポート**: OpenAIのGPTファミリーモデルやLlama2ファミリーのオープンソースモデルとの統合。 実際、Difyは主要な商用モデルとオープンソースモデル(ローカルでデプロイまたはMaaSベース)をサポートしています。

**2\. プロンプトIDE**: チームとのLLMベースのアプリケーションとサービスの視覚的なオーケストレーション。

**3\. RAGエンジン**: フルテキストインデックスまたはベクトルデータベース埋め込みに基づくさまざまなRAG機能を含み、PDF、TXT、その他のテキストフォーマットの直接アップロードを可能にします。

**4\. エージェント**: ユーザーが sees what they get を設定できる関数呼び出しベースのエージェントフレームワーク。 Difyには、Google検索などの基本的なプラグイン機能が含まれています。

**5\. 継続的運用**: アプリケーションログとパフォーマンスを監視および分析し、運用データを使用してプロンプト、データセット、またはモデルを継続的に改善します。

## 開始する前に

**私たちをスターして、GitHub上でのすべての新しいリリースに対する即時通知を受け取ります!**



- [Website](https://dify.ai)
- [Docs](https://docs.dify.ai)
- [Deployment Docs](https://docs.dify.ai/getting-started/install-self-hosted)
@@ -100,4 +120,4 @@ Difyに貢献していただき、コードの提出、問題の報告、新し

## ライセンス

このリポジトリは、[Dify Open Source License](LICENSE) のもとで利用できます。
このリポジトリは、基本的にApache 2.0にいくつかの追加制限を加えた[Difyオープンソースライセンス](LICENSE)の下で利用できます。

@@ -4,7 +4,8 @@
<a href="./README_CN.md">简体中文</a> |
<a href="./README_JA.md">日本語</a> |
<a href="./README_ES.md">Español</a> |
<a href="./README_KL.md">Klingon</a>
<a href="./README_KL.md">Klingon</a> |
<a href="./README_FR.md">Français</a>
</p>

<p align="center">
@@ -57,6 +58,10 @@ Dify Daq rIn neutrality 'ej Hoch, LangChain tInHar HubwI'. maH Daqbe'law' Qawqar

## Do'wI' qabmey lo'taH

**maHvaD jatlhchugh, GitHub Daq Hoch chu' ghompu'vam tIqel yInob!**



- [Website](https://dify.ai)
- [Docs](https://docs.dify.ai)
- [lo'taHmoH Docs](https://docs.dify.ai/getting-started/install-self-hosted)

@@ -106,8 +106,6 @@ HOSTED_OPENAI_API_BASE=
HOSTED_OPENAI_API_ORGANIZATION=
HOSTED_OPENAI_QUOTA_LIMIT=200
HOSTED_OPENAI_PAID_ENABLED=false
HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
HOSTED_OPENAI_PAID_INCREASE_QUOTA=1

HOSTED_AZURE_OPENAI_ENABLED=false
HOSTED_AZURE_OPENAI_API_KEY=
@@ -119,16 +117,6 @@ HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
HOSTED_ANTHROPIC_PAID_ENABLED=false
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100

# Stripe configuration
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=

# Billing configuration
BILLING_API_URL=http://127.0.0.1:8000/v1
BILLING_API_SECRET_KEY=
STRIPE_WEBHOOK_BILLING_SECRET=
ETL_TYPE=dify
UNSTRUCTURED_API_URL=

api/.vscode/launch.json (vendored), 15 changed lines
@@ -4,6 +4,21 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Celery",
"type": "python",
"request": "launch",
"module": "celery",
"justMyCode": true,
"args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
"envFile": "${workspaceFolder}/.env",
"env": {
"FLASK_APP": "app.py",
"FLASK_DEBUG": "1",
"GEVENT_SUPPORT": "True"
},
"console": "integratedTerminal"
},
{
"name": "Python: Flask",
"type": "python",

@@ -34,9 +34,6 @@ RUN apt-get update \
COPY --from=base /pkg /usr/local
COPY . /app/api/

RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')"
ENV TRANSFORMERS_OFFLINE true

COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

api/app.py, 35 changed lines
@@ -6,9 +6,12 @@ from werkzeug.exceptions import Unauthorized
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
monkey.patch_all()
if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()
# if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()

import langchain
langchain.verbose = True

import time
import logging
@@ -18,9 +21,8 @@ import threading
from flask import Flask, request, Response
from flask_cors import CORS

from core.model_providers.providers import hosted
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension
ext_database, ext_storage, ext_mail, ext_code_based_extension, ext_hosting_provider
from extensions.ext_database import db
from extensions.ext_login import login_manager

@@ -79,8 +81,6 @@ def create_app(test_config=None) -> Flask:
register_blueprints(app)
register_commands(app)

hosted.init_app(app)

return app

@@ -95,8 +95,8 @@ def initialize_extensions(app):
ext_celery.init_app(app)
ext_login.init_app(app)
ext_mail.init_app(app)
ext_hosting_provider.init_app(app)
ext_sentry.init_app(app)
ext_stripe.init_app(app)

# Flask-Login configuration
@@ -106,13 +106,18 @@ def load_user_from_request(request_from_flask_login):
if request.blueprint == 'console':
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get('Authorization', '')
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')

if not auth_header:
auth_token = request.args.get('_token')
if not auth_token:
raise Unauthorized('Invalid Authorization token.')
else:
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')

decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id')

@@ -12,23 +12,19 @@ import qdrant_client
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
from tqdm import tqdm
from flask import current_app, Flask
from langchain.embeddings import OpenAIEmbeddings
from werkzeug.exceptions import NotFound

from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.hosted import hosted_model_providers
from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
from models.model import Account, AppModelConfig, App, MessageAnnotation, Message
import secrets
import base64

@@ -327,6 +323,8 @@ def create_qdrant_indexes():
except NotFound:
break

model_manager = ModelManager()

page += 1
for dataset in datasets:
if dataset.index_struct_dict:
@@ -334,19 +332,23 @@
try:
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model

)
except Exception:
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
embedding_model = model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
except Exception:

provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
@@ -752,6 +754,30 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
pbar.update(len(data_batch))


@click.command('add-annotation-question-field-value', help='add annotation question value')
def add_annotation_question_field_value():
click.echo(click.style('Start add annotation question value.', fg='green'))
message_annotations = db.session.query(MessageAnnotation).all()
message_annotation_deal_count = 0
if message_annotations:
for message_annotation in message_annotations:
try:
if message_annotation.message_id and not message_annotation.question:
message = db.session.query(Message).filter(
Message.id == message_annotation.message_id
).first()
message_annotation.question = message.query
db.session.add(message_annotation)
db.session.commit()
message_annotation_deal_count += 1
except Exception as e:
click.echo(
click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.echo(
click.style(f'Congratulations! add annotation question value successful. Deal count {message_annotation_deal_count}', fg='green'))


def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
@@ -766,3 +792,4 @@ def register_commands(app):
app.cli.add_command(normalization_collections)
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
app.cli.add_command(add_qdrant_full_text_index)
app.cli.add_command(add_annotation_question_field_value)

@@ -1,11 +1,8 @@
# -*- coding:utf-8 -*-
import os
from datetime import timedelta

import dotenv

from extensions.ext_database import db
from extensions.ext_redis import redis_client

dotenv.load_dotenv()

@@ -44,15 +41,11 @@ DEFAULTS = {
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_ENABLED': 'False',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
'HOSTED_ANTHROPIC_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
'HOSTED_MODERATION_ENABLED': 'False',
'HOSTED_MODERATION_PROVIDERS': '',
'CLEAN_DAY_SETTING': 30,
@@ -61,7 +54,10 @@ DEFAULTS = {
'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
'OUTPUT_MODERATION_BUFFER_SIZE': 300,
'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64',
'INVITE_EXPIRY_HOURS': 72
'INVITE_EXPIRY_HOURS': 72,
'BILLING_ENABLED': 'False',
'CAN_REPLACE_LOGO': 'False',
'ETL_TYPE': 'dify',
}


@@ -91,7 +87,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.3.33"
self.CURRENT_VERSION = "0.4.0"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -268,8 +264,6 @@ class Config:
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))

self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
@@ -281,14 +275,15 @@ class Config:
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))

self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')

self.ETL_TYPE = get_env('ETL_TYPE')
self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')


class CloudEditionConfig(Config):

@@ -302,6 +297,3 @@ class CloudEditionConfig(Config):
self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')

self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

@@ -6,10 +6,10 @@ bp = Blueprint('console', __name__, url_prefix='/console/api')
api = ExternalApi(bp)

# Import other controllers
from . import extension, setup, version, apikey, admin
from . import extension, setup, version, apikey, admin, feature

# Import app controllers
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation

# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate
@@ -18,7 +18,7 @@ from .auth import login, oauth, data_source_oauth, activate
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source

# Import workspace controllers
from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
from .workspace import workspace, members, model_providers, account, tool_providers, models

# Import explore controllers
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
@@ -26,7 +26,4 @@ from .explore import installed_app, recommended_app, completion, conversation, m
# Import universal chat controllers
from .universal_chat import chat, conversation, message, parameter, audio

# Import webhook controllers
from .webhook import stripe

from .billing import billing

api/controllers/console/app/annotation.py (new file), 290 lines
@@ -0,0 +1,290 @@
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, marshal
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.app.error import NoFileUploadedError
from controllers.console.datasets.error import TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_list_fields, annotation_hit_history_list_fields, annotation_fields, \
annotation_hit_history_fields
from libs.login import login_required
from services.annotation_service import AppAnnotationService
from flask import request


class AnnotationReplyActionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def post(self, app_id, action):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument('score_threshold', required=True, type=float, location='json')
parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
parser.add_argument('embedding_model_name', required=True, type=str, location='json')
args = parser.parse_args()
if action == 'enable':
result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == 'disable':
result = AppAnnotationService.disable_app_annotation(app_id)
else:
raise ValueError('Unsupported annotation reply action')
return result, 200


class AppAnnotationSettingDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

app_id = str(app_id)
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
return result, 200


class AppAnnotationSettingUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_id, annotation_setting_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id)

parser = reqparse.RequestParser()
parser.add_argument('score_threshold', required=True, type=float, location='json')
args = parser.parse_args()

result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
return result, 200


class AnnotationReplyActionStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id, action):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

job_id = str(job_id)
app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job is not exist.")

job_status = cache_result.decode()
error_msg = ''
if job_status == 'error':
app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode()

return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200


class AnnotationListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
keyword = request.args.get('keyword', default=None, type=str)

app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
response = {
'data': marshal(annotation_list, annotation_fields),
'has_more': len(annotation_list) == limit,
'limit': limit,
'total': total,
'page': page
}
return response, 200


class AnnotationExportApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {
'data': marshal(annotation_list, annotation_fields)
}
return response, 200


class AnnotationCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation


class AnnotationUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

app_id = str(app_id)
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation

@setup_required
@login_required
@account_initialization_required
def delete(self, app_id, annotation_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
return {'result': 'success'}, 200


class AnnotationBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

app_id = str(app_id)
# get file from request
file = request.files['file']
# check file
if 'file' not in request.files:
raise NoFileUploadedError()

if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed")
return AppAnnotationService.batch_import_app_annotations(app_id, file)


class AnnotationBatchImportStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

job_id = str(job_id)
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
raise ValueError("The job is not exist.")
job_status = cache_result.decode()
error_msg = ''
if job_status == 'error':
indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
error_msg = redis_client.get(indexing_error_msg_key).decode()

return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200


class AnnotationHitHistoryListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id, annotation_id):
# The role of the current user in the table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
app_id = str(app_id)
annotation_id = str(annotation_id)
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
page, limit)
response = {
'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
'has_more': len(annotation_hit_history_list) == limit,
'limit': limit,
'total': total,
'page': page
}
return response


api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
api.add_resource(AnnotationReplyActionStatusApi,
'/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')
@@ -4,6 +4,10 @@ import logging
from datetime import datetime

from flask_login import current_user

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with, abort, inputs
from werkzeug.exceptions import Forbidden
@@ -13,9 +17,7 @@ from controllers.console import api
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
from events.app_event import app_was_created, app_was_deleted
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
app_detail_fields_with_site
@@ -73,39 +75,41 @@ class AppListApi(Resource):
raise Forbidden()

try:
default_model = ModelFactory.get_text_generation_model(
tenant_id=current_user.current_tenant_id
provider_manager = ProviderManager()
default_model_entity = provider_manager.get_default_model(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
except (ProviderTokenNotInitError, LLMBadRequestError):
default_model = None
default_model_entity = None
except Exception as e:
logging.exception(e)
default_model = None
default_model_entity = None

if args['model_config'] is not None:
# validate config
model_config_dict = args['model_config']

# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(
current_user.current_tenant_id,
model_config_dict["model"]["provider"]
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)

if not model_provider:
if not default_model:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
model_config_dict["model"]["name"] = default_model.name
if not model_instance:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = model_instance.provider
model_config_dict["model"]["name"] = model_instance.model

model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=model_config_dict,
mode=args['mode']
app_mode=args['mode']
)

app = App(
@@ -129,21 +133,27 @@ class AppListApi(Resource):
app_model_config = AppModelConfig(**model_config_template['model_config'])

# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(
current_user.current_tenant_id,
app_model_config.model_dict["provider"]
)
model_manager = ModelManager()

if not model_provider:
if not default_model:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_dict = app_model_config.model_dict
model_dict['provider'] = default_model.model_provider.provider_name
model_dict['name'] = default_model.name
app_model_config.model = json.dumps(model_dict)
try:
model_instance = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")

if not model_instance:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_dict = app_model_config.model_dict
model_dict['provider'] = model_instance.provider
model_dict['name'] = model_instance.model
app_model_config.model = json.dumps(model_dict)

app.name = args['name']
app.mode = args['mode']
@@ -2,6 +2,8 @@
import logging

from flask import request

from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from werkzeug.exceptions import InternalServerError

@@ -14,8 +16,7 @@ from controllers.console.app.error import AppUnavailableError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from flask_restful import Resource
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
@@ -56,8 +57,7 @@ class ChatMessageAudioApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -5,6 +5,10 @@ from typing import Generator, Union

import flask_login
from flask import Response, stream_with_context

from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from werkzeug.exceptions import InternalServerError, NotFound

@@ -16,9 +20,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.conversation_message_task import PubHandler
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from flask_restful import Resource, reqparse

@@ -56,7 +58,7 @@ class CompletionMessageApi(Resource):
app_model=app_model,
user=account,
args=args,
from_source='console',
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming,
is_model_config_override=True
)
@@ -75,8 +77,7 @@ class CompletionMessageApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -97,7 +98,7 @@ class CompletionMessageStopApi(Resource):

account = flask_login.current_user

PubHandler.stop(account, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)

return {'result': 'success'}, 200

@@ -132,7 +133,7 @@ class ChatMessageApi(Resource):
app_model=app_model,
user=account,
args=args,
from_source='console',
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming,
is_model_config_override=True
)
@@ -151,8 +152,7 @@ class ChatMessageApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -182,9 +182,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
@@ -207,7 +206,7 @@ class ChatMessageStopApi(Resource):

account = flask_login.current_user

PubHandler.stop(account, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)

return {'result': 'success'}, 200
@@ -72,4 +72,16 @@ class UnsupportedAudioTypeError(BaseHTTPException):
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text."
code = 400
code = 400


class NoFileUploadedError(BaseHTTPException):
error_code = 'no_file_uploaded'
description = "Please upload your file."
code = 400


class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files'
description = "Only one file is allowed."
code = 400
@@ -1,4 +1,6 @@
from flask_login import current_user

from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from flask_restful import Resource, reqparse

@@ -8,8 +10,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.generator.llm_generator import LLMGenerator
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError


class RuleGenerateApi(Resource):
@@ -36,8 +37,7 @@ class RuleGenerateApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))

return rules
@@ -6,22 +6,24 @@ from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import Resource, reqparse, marshal_with, fields
from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden

from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from fields.conversation_fields import message_detail_fields
from fields.conversation_fields import message_detail_fields, annotation_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@@ -151,44 +153,24 @@ class MessageAnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

app_id = str(app_id)

# get app info
app = _get_app(app_id)

parser = reqparse.RequestParser()
parser.add_argument('message_id', required=True, type=uuid_value, location='json')
parser.add_argument('content', type=str, location='json')
parser.add_argument('message_id', required=False, type=uuid_value, location='json')
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
parser.add_argument('annotation_reply', required=False, type=dict, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)

message_id = str(args['message_id'])

message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app.id
).first()

if not message:
raise NotFound("Message Not Exists.")

annotation = message.annotation

if annotation:
annotation.content = args['content']
else:
annotation = MessageAnnotation(
app_id=app.id,
conversation_id=message.conversation_id,
message_id=message.id,
content=args['content'],
account_id=current_user.id
)
db.session.add(annotation)

db.session.commit()

return {'result': 'success'}
return annotation


class MessageAnnotationCountApi(Resource):
@@ -227,7 +209,13 @@ class MessageMoreLikeThisApi(Resource):
app_model = _get_app(app_id, 'completion')

try:
response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
response = CompletionService.generate_more_like_this(
app_model=app_model,
user=current_user,
message_id=message_id,
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming
)
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@@ -239,8 +227,7 @@ class MessageMoreLikeThisApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -268,8 +255,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(
api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -309,8 +295,7 @@ class MessageSuggestedQuestionApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")
@@ -24,29 +24,29 @@ class ModelConfigResource(Resource):
"""Modify app model config"""
app_id = str(app_id)

app_model = _get_app(app_id)
app = _get_app(app_id)

# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=request.json,
mode=app_model.mode
app_mode=app.mode
)

new_app_model_config = AppModelConfig(
app_id=app_model.id,
app_id=app.id,
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)

db.session.add(new_app_model_config)
db.session.flush()

app_model.app_model_config_id = new_app_model_config.id
app.app_model_config_id = new_app_model_config.id
db.session.commit()

app_model_config_was_updated.send(
app_model,
app,
app_model_config=new_app_model_config
)
@@ -1,9 +1,5 @@
import stripe
import os

from flask_restful import Resource, reqparse
from flask_login import current_user
from flask import current_app, request

from controllers.console import api
from controllers.console.setup import setup_required
@@ -13,20 +9,6 @@ from libs.login import login_required
from services.billing_service import BillingService


class BillingInfo(Resource):

@setup_required
@login_required
@account_initialization_required
def get(self):

edition = current_app.config['EDITION']
if edition != 'CLOUD':
return {"enabled": False}

return BillingService.get_info(current_user.current_tenant_id)


class Subscription(Resource):

@setup_required
@@ -40,7 +22,12 @@ class Subscription(Resource):
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
args = parser.parse_args()

return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id)
BillingService.is_tenant_owner(current_user)

return BillingService.get_subscription(args['plan'],
args['interval'],
current_user.email,
current_user.current_tenant_id)


class Invoices(Resource):
@@ -50,36 +37,9 @@ class Invoices(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):

BillingService.is_tenant_owner(current_user)
return BillingService.get_invoices(current_user.email)


class StripeBillingWebhook(Resource):

@setup_required
@only_edition_cloud
def post(self):
payload = request.data
sig_header = request.headers.get('STRIPE_SIGNATURE')
webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET')

try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError as e:
# Invalid payload
return 'Invalid payload', 400
except stripe.error.SignatureVerificationError as e:
# Invalid signature
return 'Invalid signature', 400

BillingService.process_event(event)

return 'success', 200


api.add_resource(BillingInfo, '/billing/info')
api.add_resource(Subscription, '/billing/subscription')
api.add_resource(Invoices, '/billing/invoices')
api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe')
@@ -4,6 +4,8 @@ from flask import request, current_app
from flask_login import current_user

from controllers.console.apikey import api_key_list, api_key_fields
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
@@ -14,8 +16,7 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.models.entity.model_params import ModelType
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
@@ -23,7 +24,6 @@ from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
from models.model import UploadFile, ApiToken
from services.dataset_service import DatasetService, DocumentService
from services.provider_service import ProviderService


def _validate_name(name):
@@ -55,16 +55,20 @@ class DatasetListApi(Resource):
current_user.current_tenant_id, current_user)

# check embedding setting
provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
# if len(valid_model_list) == 0:
#     raise ProviderNotInitializeError(
#         f"No Embedding Model available. Please configure a valid provider "
#         f"in the Settings -> Model Provider.")
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)

embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)

model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

data = marshal(datasets, dataset_detail_fields)
for item in data:
if item['indexing_technique'] == 'high_quality':
@@ -75,6 +79,7 @@ class DatasetListApi(Resource):
item['embedding_available'] = False
else:
item['embedding_available'] = True

response = {
'data': data,
'has_more': len(datasets) == limit,
@@ -130,13 +135,20 @@ class DatasetApi(Resource):
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
# check embedding setting
provider_service = ProviderService()
# get valid model list
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)

embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)

model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

if data['indexing_technique'] == 'high_quality':
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
@@ -2,8 +2,12 @@
from datetime import datetime
from typing import List

from flask import request, current_app
from flask import request
from flask_login import current_user

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from libs.login import login_required
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import desc, asc
@@ -18,9 +22,8 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.indexing_runner import IndexingRunner
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from fields.document_fields import document_with_segments_fields, document_fields, \
dataset_and_document_fields, document_status_fields
@@ -272,10 +275,12 @@ class DatasetInitApi(Resource):
args = parser.parse_args()
if args['indexing_technique'] == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
model_manager = ModelManager()
model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
except LLMBadRequestError:
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
@@ -410,7 +415,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if dataset.data_source_type == 'upload_file':
file_details = db.session.query(UploadFile).filter(
UploadFile.tenant_id == current_user.current_tenant_id,
UploadFile.id in info_list
UploadFile.id.in_(info_list)
).all()

if file_details is None:
@@ -12,8 +12,9 @@ from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from libs.login import login_required
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -133,10 +134,12 @@ class DatasetDocumentSegmentApi(Resource):
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
@@ -219,10 +222,12 @@ class DatasetDocumentSegmentAddApi(Resource):
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
@@ -269,10 +274,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
@@ -69,5 +69,20 @@ class FilePreviewApi(Resource):
return {'content': text}


class FileSupportTypeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
etl_type = current_app.config['ETL_TYPE']
if etl_type == 'Unstructured':
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml']
else:
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
return {'allowed_extensions': allowed_extensions}


api.add_resource(FileApi, '/files/upload')
api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
api.add_resource(FileSupportTypeApi, '/files/support-type')
@@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError
from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService
@@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.explore.wraps import InstalledAppResource
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@@ -53,8 +53,7 @@ class ChatAudioApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -15,9 +15,10 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.conversation_message_task import PubHandler
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs.helper import uuid_value
from services.completion_service import CompletionService
@@ -50,7 +51,7 @@ class CompletionApi(InstalledAppResource):
app_model=app_model,
user=current_user,
args=args,
from_source='console',
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
)

@@ -68,8 +69,7 @@ class CompletionApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -84,7 +84,7 @@ class CompletionStopApi(InstalledAppResource):
if app_model.mode != 'completion':
raise NotCompletionAppError()

PubHandler.stop(current_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

return {'result': 'success'}, 200

@@ -115,7 +115,7 @@ class ChatApi(InstalledAppResource):
app_model=app_model,
user=current_user,
args=args,
from_source='console',
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
)

@@ -133,8 +133,7 @@ class ChatApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -149,7 +148,7 @@ class ChatStopApi(InstalledAppResource):
if app_model.mode != 'chat':
raise NotChatAppError()

PubHandler.stop(current_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

return {'result': 'success'}, 200

@@ -175,8 +174,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -73,7 +73,7 @@ class ConversationRenameApi(InstalledAppResource):

parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()

try:
@@ -5,7 +5,7 @@ from typing import Generator, Union

from flask import stream_with_context, Response
from flask_login import current_user
from flask_restful import reqparse, fields, marshal_with
from flask_restful import reqparse, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound, InternalServerError

@@ -13,12 +13,14 @@ import services
from controllers.console import api
from controllers.console.app.error import AppMoreLikeThisDisabledError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs.helper import uuid_value, TimestampField
from libs.helper import uuid_value
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@@ -83,7 +85,13 @@ class MessageMoreLikeThisApi(InstalledAppResource):
streaming = args['response_mode'] == 'streaming'

try:
response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
response = CompletionService.generate_more_like_this(
app_model=app_model,
user=current_user,
message_id=message_id,
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
)
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@@ -95,8 +103,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -123,8 +130,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -162,8 +168,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")
@@ -30,6 +30,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
@@ -49,6 +50,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
14
api/controllers/console/feature.py
Normal file
@@ -0,0 +1,14 @@
from flask_restful import Resource
from flask_login import current_user

from . import api
from services.feature_service import FeatureService


class FeatureApi(Resource):

def get(self):
return FeatureService.get_features(current_user.current_tenant_id).dict()


api.add_resource(FeatureApi, '/features')
@@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
@@ -53,8 +53,7 @@ class UniversalChatAudioApi(UniversalChatResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -12,9 +12,10 @@ from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.conversation_message_task import PubHandler
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from services.completion_service import CompletionService

@@ -68,7 +69,7 @@ class UniversalChatApi(UniversalChatResource):
app_model=app_model,
user=current_user,
args=args,
from_source='console',
invoke_from=InvokeFrom.EXPLORE,
streaming=True,
is_model_config_override=True,
)
@@ -87,8 +88,7 @@ class UniversalChatApi(UniversalChatResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
@@ -99,7 +99,7 @@ class UniversalChatApi(UniversalChatResource):

class UniversalChatStopApi(UniversalChatResource):
def post(self, universal_app, task_id):
PubHandler.stop(current_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

return {'result': 'success'}, 200

@@ -125,8 +125,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
@@ -66,7 +66,7 @@ class UniversalChatConversationRenameApi(UniversalChatResource):

parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()

try:
@@ -12,8 +12,8 @@ from controllers.console.app.error import ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value, TimestampField
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@@ -132,8 +132,7 @@ class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")
@@ -17,6 +17,7 @@ class UniversalChatParameterApi(UniversalChatResource):
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw
}

@marshal_with(parameters_fields)
@@ -32,6 +33,7 @@ class UniversalChatParameterApi(UniversalChatResource):
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
}
@@ -1,61 +0,0 @@
import logging

import stripe
from flask import request, current_app
from flask_restful import Resource

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import only_edition_cloud
from services.provider_checkout_service import ProviderCheckoutService


class StripeWebhookApi(Resource):
@setup_required
@only_edition_cloud
def post(self):
payload = request.data
sig_header = request.headers.get('STRIPE_SIGNATURE')
webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')

try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError as e:
# Invalid payload
return 'Invalid payload', 400
except stripe.error.SignatureVerificationError as e:
# Invalid signature
return 'Invalid signature', 400

# Handle the checkout.session.completed event
if event['type'] == 'checkout.session.completed':
logging.debug(event['data']['object']['id'])
logging.debug(event['data']['object']['amount_subtotal'])
logging.debug(event['data']['object']['currency'])
logging.debug(event['data']['object']['payment_intent'])
logging.debug(event['data']['object']['payment_status'])
logging.debug(event['data']['object']['metadata'])

session = stripe.checkout.Session.retrieve(
event['data']['object']['id'],
expand=['line_items'],
)

logging.debug(session.line_items['data'][0]['quantity'])

# Fulfill the purchase...
provider_checkout_service = ProviderCheckoutService()

try:
provider_checkout_service.fulfill_provider_order(event, session.line_items)
except Exception as e:

logging.debug(str(e))
return 'success', 200

return 'success', 200


api.add_resource(StripeWebhookApi, '/webhook/stripe')
@@ -1,16 +1,19 @@
import io

from flask import send_file
from flask_login import current_user
from libs.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import CredentialsValidateFailedError
from services.provider_checkout_service import ProviderCheckoutService
from services.provider_service import ProviderService
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService


class ModelProviderListApi(Resource):
@@ -22,13 +25,36 @@ class ModelProviderListApi(Resource):
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
parser.add_argument('model_type', type=str, required=False, nullable=True, location='args')
parser.add_argument('model_type', type=str, required=False, nullable=True,
choices=[mt.value for mt in ModelType], location='args')
args = parser.parse_args()

provider_service = ProviderService()
provider_list = provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get('model_type'))
model_provider_service = ModelProviderService()
provider_list = model_provider_service.get_provider_list(
tenant_id=tenant_id,
model_type=args.get('model_type')
)

return provider_list
return jsonable_encoder({"data": provider_list})


class ModelProviderCredentialApi(Resource):

@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
tenant_id = current_user.current_tenant_id

model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credentials(
tenant_id=tenant_id,
provider=provider
)

return {
"credentials": credentials
}


class ModelProviderValidateApi(Resource):
@@ -36,21 +62,24 @@ class ModelProviderValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_name: str):
def post(self, provider: str):

parser = reqparse.RequestParser()
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()

provider_service = ProviderService()
tenant_id = current_user.current_tenant_id

model_provider_service = ModelProviderService()

result = True
error = None

try:
provider_service.custom_provider_config_validate(
provider_name=provider_name,
config=args['config']
model_provider_service.provider_credentials_validate(
tenant_id=tenant_id,
provider=provider,
credentials=args['credentials']
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
@@ -64,26 +93,26 @@ class ModelProviderValidateApi(Resource):
|
||||
return response
|
||||
|
||||
|
||||
class ModelProviderUpdateApi(Resource):
|
||||
class ModelProviderApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
def post(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
provider_service.save_custom_provider_config(
|
||||
model_provider_service.save_provider_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
config=args['config']
|
||||
provider=provider,
|
||||
credentials=args['credentials']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
@@ -93,109 +122,36 @@ class ModelProviderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_name: str):
|
||||
def delete(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.delete_custom_provider(
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name
|
||||
provider=provider
|
||||
)
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class ModelProviderModelValidateApi(Resource):
|
||||
class ModelProviderIconApi(Resource):
|
||||
"""
|
||||
Get model provider icon
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
try:
|
||||
provider_service.custom_provider_model_config_validate(
|
||||
provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type=args['model_type'],
|
||||
config=args['config']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
error = str(ex)
|
||||
|
||||
response = {'result': 'success' if result else 'error'}
|
||||
|
||||
if not result:
|
||||
response['error'] = error
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ModelProviderModelUpdateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
|
||||
parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
try:
|
||||
provider_service.add_or_save_custom_provider_model_config(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type=args['model_type'],
|
||||
config=args['config']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_name: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.delete_custom_provider_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type=args['model_type']
|
||||
def get(self, provider: str, icon_type: str, lang: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
icon, mimetype = model_provider_service.get_model_provider_icon(
|
||||
provider=provider,
|
||||
icon_type=icon_type,
|
||||
lang=lang
|
||||
)
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
return send_file(io.BytesIO(icon), mimetype=mimetype)
|
||||
|
||||
|
||||
class PreferredProviderTypeUpdateApi(Resource):
|
||||
@@ -203,88 +159,50 @@ class PreferredProviderTypeUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
def post(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
|
||||
choices=['system', 'custom'], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.switch_preferred_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.switch_preferred_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
preferred_provider_type=args['preferred_provider_type']
|
||||
)
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class ModelProviderModelParameterRuleApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
try:
|
||||
parameter_rules = provider_service.get_model_parameter_rules(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=provider_name,
|
||||
model_name=args['model_name'],
|
||||
model_type='text-generation'
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"Current Text Generation Model is invalid. Please switch to the available model.")
|
||||
|
||||
rules = {
|
||||
k: {
|
||||
'enabled': v.enabled,
|
||||
'min': v.min,
|
||||
'max': v.max,
|
||||
'default': v.default,
|
||||
'precision': v.precision
|
||||
}
|
||||
for k, v in vars(parameter_rules).items()
|
||||
}
|
||||
|
||||
return rules
|
||||
|
||||
|
||||
class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
provider_service = ProviderCheckoutService()
|
||||
provider_checkout = provider_service.create_checkout(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
account=current_user
|
||||
)
|
||||
def get(self, provider: str):
|
||||
if provider != 'anthropic':
|
||||
raise ValueError(f'provider name {provider} is invalid')
|
||||
|
||||
return {
|
||||
'url': provider_checkout.get_checkout_url()
|
||||
}
|
||||
data = BillingService.get_model_provider_payment_link(provider_name=provider,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account_id=current_user.id)
|
||||
return data
|
||||
|
||||
|
||||
class ModelProviderFreeQuotaSubmitApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_name: str):
|
||||
provider_service = ProviderService()
|
||||
result = provider_service.free_quota_submit(
|
||||
def post(self, provider: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
result = model_provider_service.free_quota_submit(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name
|
||||
provider=provider
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -294,15 +212,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
def get(self, provider: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', type=str, required=False, nullable=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
result = provider_service.free_quota_qualification_verify(
|
||||
model_provider_service = ModelProviderService()
|
||||
result = model_provider_service.free_quota_qualification_verify(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider_name,
|
||||
provider=provider,
|
||||
token=args['token']
|
||||
)
|
||||
|
||||
@@ -310,19 +228,18 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
|
||||
|
||||
|
||||
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
|
||||
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
|
||||
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
|
||||
api.add_resource(ModelProviderModelValidateApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/models/validate')
|
||||
api.add_resource(ModelProviderModelUpdateApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/models')
|
||||
|
||||
api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials')
|
||||
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate')
|
||||
api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>')
|
||||
api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/'
|
||||
'<string:icon_type>/<string:lang>')
|
||||
|
||||
api.add_resource(PreferredProviderTypeUpdateApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type')
|
||||
api.add_resource(ModelProviderModelParameterRuleApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
|
||||
'/workspaces/current/model-providers/<string:provider>/preferred-provider-type')
|
||||
api.add_resource(ModelProviderPaymentCheckoutUrlApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')
|
||||
'/workspaces/current/model-providers/<string:provider>/checkout-url')
|
||||
api.add_resource(ModelProviderFreeQuotaSubmitApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
|
||||
'/workspaces/current/model-providers/<string:provider>/free-quota-submit')
|
||||
api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
|
||||
'/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')
|
||||
'/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify')
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from flask_restful import reqparse, Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from models.provider import ProviderType
|
||||
from services.provider_service import ProviderService
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
class DefaultModelApi(Resource):
|
||||
@@ -21,52 +22,20 @@ class DefaultModelApi(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
|
||||
choices=[mt.value for mt in ModelType], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
provider_service = ProviderService()
|
||||
default_model = provider_service.get_default_model_of_model_type(
|
||||
model_provider_service = ModelProviderService()
|
||||
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
||||
tenant_id=tenant_id,
|
||||
model_type=args['model_type']
|
||||
)
|
||||
|
||||
if not default_model:
|
||||
return None
|
||||
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(
|
||||
tenant_id,
|
||||
default_model.provider_name
|
||||
)
|
||||
|
||||
if not model_provider:
|
||||
return {
|
||||
'model_name': default_model.model_name,
|
||||
'model_type': default_model.model_type,
|
||||
'model_provider': {
|
||||
'provider_name': default_model.provider_name
|
||||
}
|
||||
}
|
||||
|
||||
provider = model_provider.provider
|
||||
rst = {
|
||||
'model_name': default_model.model_name,
|
||||
'model_type': default_model.model_type,
|
||||
'model_provider': {
|
||||
'provider_name': provider.provider_name,
|
||||
'provider_type': provider.provider_type
|
||||
}
|
||||
}
|
||||
|
||||
model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
|
||||
if provider.provider_type == ProviderType.SYSTEM.value:
|
||||
rst['model_provider']['quota_type'] = provider.quota_type
|
||||
rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
|
||||
rst['model_provider']['quota_limit'] = provider.quota_limit
|
||||
rst['model_provider']['quota_used'] = provider.quota_used
|
||||
|
||||
return rst
|
||||
return jsonable_encoder({
|
||||
"data": default_model_entity
|
||||
})
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -76,15 +45,26 @@ class DefaultModelApi(Resource):
|
||||
parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_settings = args['model_settings']
|
||||
for model_setting in model_settings:
|
||||
if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]:
|
||||
raise ValueError('invalid model type')
|
||||
|
||||
if 'provider' not in model_setting:
|
||||
continue
|
||||
|
||||
if 'model' not in model_setting:
|
||||
raise ValueError('invalid model')
|
||||
|
||||
try:
|
||||
provider_service.update_default_model_of_model_type(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_service.update_default_model_of_model_type(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_setting['model_type'],
|
||||
provider_name=model_setting['provider_name'],
|
||||
model_name=model_setting['model_name']
|
||||
provider=model_setting['provider'],
|
||||
model=model_setting['model']
|
||||
)
|
||||
except Exception:
|
||||
logging.warning(f"{model_setting['model_type']} save error")
|
||||
@@ -92,22 +72,198 @@ class DefaultModelApi(Resource):
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class ValidModelApi(Resource):
|
||||
class ModelProviderModelApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider
|
||||
)
|
||||
|
||||
return jsonable_encoder({
|
||||
"data": models
|
||||
})
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=[mt.value for mt in ModelType], location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.save_model_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args['model'],
|
||||
model_type=args['model_type'],
|
||||
credentials=args['credentials']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=[mt.value for mt in ModelType], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args['model'],
|
||||
model_type=args['model_type']
|
||||
)
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class ModelProviderModelCredentialApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model', type=str, required=True, nullable=False, location='args')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=[mt.value for mt in ModelType], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
credentials = model_provider_service.get_model_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=args['model_type'],
|
||||
model=args['model']
|
||||
)
|
||||
|
||||
return {
|
||||
"credentials": credentials
|
||||
}
|
||||
|
||||
|
||||
class ModelProviderModelValidateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=[mt.value for mt in ModelType], location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
try:
|
||||
model_provider_service.model_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args['model'],
|
||||
model_type=args['model_type'],
|
||||
credentials=args['credentials']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
error = str(ex)
|
||||
|
||||
response = {'result': 'success' if result else 'error'}
|
||||
|
||||
if not result:
|
||||
response['error'] = error
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ModelProviderModelParameterRuleApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model', type=str, required=True, nullable=False, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
parameter_rules = model_provider_service.get_model_parameter_rules(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args['model']
|
||||
)
|
||||
|
||||
return jsonable_encoder({
|
||||
"data": parameter_rules
|
||||
})
|
||||
|
||||
|
||||
class ModelProviderAvailableModelApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, model_type):
|
||||
ModelType.value_of(model_type)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
provider_service = ProviderService()
|
||||
valid_models = provider_service.get_valid_model_list(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type
|
||||
)
|
||||
|
||||
return valid_models
|
||||
return jsonable_encoder({
|
||||
"data": models
|
||||
})
|
||||
|
||||
|
||||
api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
|
||||
api.add_resource(ModelProviderModelCredentialApi,
|
||||
'/workspaces/current/model-providers/<string:provider>/models/credentials')
|
||||
api.add_resource(ModelProviderModelValidateApi,
|
||||
'/workspaces/current/model-providers/<string:provider>/models/credentials/validate')
|
||||
|
||||
api.add_resource(ModelProviderModelParameterRuleApi,
|
||||
'/workspaces/current/model-providers/<string:provider>/models/parameter-rules')
|
||||
api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>')
|
||||
api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
|
||||
api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from models.provider import ProviderType
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
class ProviderListApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
"""
|
||||
If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
|
||||
azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
|
||||
rest is replaced by * and the last two bits are displayed in plaintext
|
||||
|
||||
If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
|
||||
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
|
||||
"""
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_info_list = provider_service.get_provider_list(tenant_id)
|
||||
|
||||
provider_list = [
|
||||
{
|
||||
'provider_name': p['provider_name'],
|
||||
'provider_type': p['provider_type'],
|
||||
'is_valid': p['is_valid'],
|
||||
'last_used': p['last_used'],
|
||||
'is_enabled': p['is_valid'],
|
||||
**({
|
||||
'quota_type': p['quota_type'],
|
||||
'quota_limit': p['quota_limit'],
|
||||
'quota_used': p['quota_used']
|
||||
} if p['provider_type'] == ProviderType.SYSTEM.value else {}),
|
||||
'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
|
||||
if p['config'] else None
|
||||
}
|
||||
for name, provider_info in provider_info_list.items()
|
||||
for p in provider_info['providers']
|
||||
]
|
||||
|
||||
return provider_list
|
||||
|
||||
|
||||
class ProviderTokenApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if provider == 'openai':
|
||||
args['token'] = {
|
||||
'openai_api_key': args['token']
|
||||
}
|
||||
|
||||
provider_service = ProviderService()
|
||||
try:
|
||||
provider_service.save_custom_provider_config(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_name=provider,
|
||||
config=args['token']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {'result': 'success'}, 201
|
||||
|
||||
|
||||
class ProviderTokenValidateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('token', required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
|
||||
if provider == 'openai':
|
||||
args['token'] = {
|
||||
'openai_api_key': args['token']
|
||||
}
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
try:
|
||||
provider_service.custom_provider_config_validate(
|
||||
provider_name=provider,
|
||||
config=args['token']
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
error = str(ex)
|
||||
|
||||
response = {'result': 'success' if result else 'error'}
|
||||
|
||||
if not result:
|
||||
response['error'] = error
|
||||
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
|
||||
endpoint='workspaces_current_providers_token') # PUT for updating provider token
|
||||
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
|
||||
endpoint='workspaces_current_providers_token_validate') # POST for validating provider token
|
||||
|
||||
api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list
|
||||
@@ -10,12 +10,15 @@ from controllers.console import api
|
||||
from controllers.console.admin import admin_required
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, UnsupportedFileTypeError
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
import services
|
||||
from services.account_service import TenantService
|
||||
from services.workspace_service import WorkspaceService
|
||||
from services.file_service import FileService
|
||||
|
||||
provider_fields = {
|
||||
'provider_name': fields.String,
|
||||
@@ -31,9 +34,9 @@ tenant_fields = {
|
||||
'status': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'role': fields.String,
|
||||
'providers': fields.List(fields.Nested(provider_fields)),
|
||||
'in_trial': fields.Boolean,
|
||||
'trial_end_reason': fields.String,
|
||||
'custom_config': fields.Raw(attribute='custom_config'),
|
||||
}
|
||||
|
||||
tenants_fields = {
|
||||
@@ -130,6 +133,61 @@ class SwitchWorkspaceApi(Resource):
|
||||
new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant
|
||||
|
||||
return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
|
||||
|
||||
|
||||
class CustomConfigWorkspaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('workspace_custom')
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('remove_webapp_brand', type=bool, location='json')
|
||||
parser.add_argument('replace_webapp_logo', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
custom_config_dict = {
|
||||
'remove_webapp_brand': args['remove_webapp_brand'],
|
||||
'replace_webapp_logo': args['replace_webapp_logo'],
|
||||
}
|
||||
|
||||
tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404()
|
||||
|
||||
tenant.custom_config_dict = custom_config_dict
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
|
||||
|
||||
|
||||
class WebappLogoWorkspaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('workspace_custom')
|
||||
def post(self):
|
||||
# get file from request
|
||||
file = request.files['file']
|
||||
|
||||
# check file
|
||||
if 'file' not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
extension = file.filename.split('.')[-1]
|
||||
if extension.lower() not in ['svg', 'png']:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(file, current_user, True)
|
||||
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return { 'id': upload_file.id }, 201
|
||||
|
||||
|
||||
api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants
|
||||
@@ -137,3 +195,5 @@ api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all ten
|
||||
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info
|
||||
api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated
|
||||
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant
|
||||
api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config')
|
||||
api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload')
|
||||
|
||||
@@ -5,7 +5,7 @@ from flask import current_app, abort
 from flask_login import current_user
 
 from controllers.console.workspace.error import AccountNotInitializedError
-from services.billing_service import BillingService
+from services.feature_service import FeatureService
 
 
 def account_initialization_required(view):
@@ -49,18 +49,23 @@ def cloud_edition_billing_resource_check(resource: str,
     def interceptor(view):
         @wraps(view)
         def decorated(*args, **kwargs):
             if current_app.config['EDITION'] == 'CLOUD':
-                tenant_id = current_user.current_tenant_id
-                billing_info = BillingService.get_info(tenant_id)
-                members = billing_info['members']
-                apps = billing_info['apps']
-                vector_space = billing_info['vector_space']
+                features = FeatureService.get_features(current_user.current_tenant_id)
 
-                if resource == 'members' and 0 < members['limit'] <= members['size']:
+                if features.billing.enabled:
+                    members = features.members
+                    apps = features.apps
+                    vector_space = features.vector_space
+                    annotation_quota_limit = features.annotation_quota_limit
+
+                if resource == 'members' and 0 < members.limit <= members.size:
                     abort(403, error_msg)
-                elif resource == 'apps' and 0 < apps['limit'] <= apps['size']:
+                elif resource == 'apps' and 0 < apps.limit <= apps.size:
                     abort(403, error_msg)
-                elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
+                elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
                     abort(403, error_msg)
+                elif resource == 'workspace_custom' and not features.can_replace_logo:
+                    abort(403, error_msg)
+                elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
+                    abort(403, error_msg)
                 else:
                     return view(*args, **kwargs)
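With this change the quota check reads limits from FeatureService instead of calling BillingService inline, and gains the workspace_custom and annotation resources. Applying it to an endpoint is unchanged; a rough usage sketch, where CreateAppApi is an illustrative class name and the decorators are the ones imported elsewhere in this compare:

# Illustrative usage of the console quota decorator; the resource name
# ('apps', 'members', 'vector_space', 'workspace_custom', 'annotation')
# selects which FeatureService limit is enforced.
from flask_restful import Resource

from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from libs.login import login_required


class CreateAppApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check('apps')
    def post(self):
        # only reached when the workspace is still under its app quota
        ...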
@@ -1,10 +1,12 @@
 from flask import request, Response
 from flask_restful import Resource
+from werkzeug.exceptions import NotFound
 
 import services
 from controllers.files import api
 from libs.exception import BaseHTTPException
 from services.file_service import FileService
+from services.account_service import TenantService
 
 
 class ImagePreviewApi(Resource):
@@ -29,9 +31,30 @@
             raise UnsupportedFileTypeError()
 
         return Response(generator, mimetype=mimetype)
+
+
+class WorkspaceWebappLogoApi(Resource):
+    def get(self, workspace_id):
+        workspace_id = str(workspace_id)
+
+        custom_config = TenantService.get_custom_config(workspace_id)
+        webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None
+
+        if not webapp_logo_file_id:
+            raise NotFound(f'webapp logo is not found')
+
+        try:
+            generator, mimetype = FileService.get_public_image_preview(
+                webapp_logo_file_id,
+            )
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return Response(generator, mimetype=mimetype)
 
 
 api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
+api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces/<uuid:workspace_id>/webapp-logo')
 
 
 class UnsupportedFileTypeError(BaseHTTPException):
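The new WorkspaceWebappLogoApi route is public, so a web app can fetch the custom logo without credentials. An illustrative client-side fetch; the base URL, port and workspace id are placeholders, not values taken from this compare:

# Fetch the custom web-app logo exposed by the files controller above.
import requests

FILES_BASE_URL = 'http://localhost:5001/files'  # assumption: local API serving the /files blueprint
workspace_id = '00000000-0000-0000-0000-000000000000'  # placeholder workspace UUID

resp = requests.get(f'{FILES_BASE_URL}/workspaces/{workspace_id}/webapp-logo')
if resp.status_code == 200:
    # the mimetype comes back in resp.headers['Content-Type'] (svg or png)
    with open('webapp-logo.png', 'wb') as f:
        f.write(resp.content)
else:
    print(resp.status_code, resp.text)  # 404 when no custom logo is configured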
@@ -31,6 +31,7 @@ class AppParameterApi(AppApiResource):
         'suggested_questions_after_answer': fields.Raw,
         'speech_to_text': fields.Raw,
         'retriever_resource': fields.Raw,
+        'annotation_reply': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
         'sensitive_word_avoidance': fields.Raw,
@@ -49,6 +50,7 @@ class AppParameterApi(AppApiResource):
         'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
         'speech_to_text': app_model_config.speech_to_text_dict,
         'retriever_resource': app_model_config.retriever_resource_dict,
+        'annotation_reply': app_model_config.annotation_reply_dict,
         'more_like_this': app_model_config.more_like_this_dict,
         'user_input_form': app_model_config.user_input_form_list,
         'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
@@ -9,8 +9,8 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
     ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
     ProviderNotSupportSpeechToTextError
 from controllers.service_api.wraps import AppApiResource
-from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
-    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_runtime.errors.invoke import InvokeError
 from models.model import App, AppModelConfig
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
@@ -49,8 +49,7 @@ class AudioApi(AppApiResource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
+        except InvokeError as e:
             raise CompletionRequestError(str(e))
         except ValueError as e:
             raise e
@@ -13,9 +13,10 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
|
||||
ConversationCompletedError, CompletionRequestError, ProviderQuotaExceededError, \
|
||||
ProviderModelCurrentlyNotSupportError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
|
||||
@@ -47,7 +48,7 @@ class CompletionApi(AppApiResource):
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
args=args,
|
||||
from_source='api',
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
@@ -65,8 +66,7 @@ class CompletionApi(AppApiResource):
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
@@ -80,7 +80,7 @@ class CompletionStopApi(AppApiResource):
|
||||
if app_model.mode != 'completion':
|
||||
raise AppUnavailableError()
|
||||
|
||||
PubHandler.stop(end_user, task_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
@@ -98,7 +98,7 @@ class ChatApi(AppApiResource):
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('user', type=str, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json')
|
||||
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -112,7 +112,7 @@ class ChatApi(AppApiResource):
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
args=args,
|
||||
from_source='api',
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
@@ -130,8 +130,7 @@ class ChatApi(AppApiResource):
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
@@ -145,7 +144,7 @@ class ChatStopApi(AppApiResource):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
PubHandler.stop(end_user, task_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
@@ -171,8 +170,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||
except ValueError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||
|
||||
@@ -4,11 +4,11 @@ import services.dataset_service
 from controllers.service_api import api
 from controllers.service_api.dataset.error import DatasetNameDuplicateError
 from controllers.service_api.wraps import DatasetApiResource
+from core.model_runtime.entities.model_entities import ModelType
+from core.provider_manager import ProviderManager
 from libs.login import current_user
-from core.model_providers.models.entity.model_params import ModelType
 from fields.dataset_fields import dataset_detail_fields
 from services.dataset_service import DatasetService
-from services.provider_service import ProviderService
 
 
 def _validate_name(name):
@@ -27,12 +27,20 @@ class DatasetApi(DatasetApiResource):
         datasets, total = DatasetService.get_datasets(page, limit, provider,
                                                       tenant_id, current_user)
         # check embedding setting
-        provider_service = ProviderService()
-        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
-                                                                 ModelType.EMBEDDINGS.value)
+        provider_manager = ProviderManager()
+        configurations = provider_manager.get_configurations(
+            tenant_id=current_user.current_tenant_id
+        )
+
+        embedding_models = configurations.get_models(
+            model_type=ModelType.TEXT_EMBEDDING,
+            only_active=True
+        )
+
         model_names = []
-        for valid_model in valid_model_list:
-            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+        for embedding_model in embedding_models:
+            model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
+
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
             if item['indexing_technique'] == 'high_quality':
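The rewritten block feeds an availability check: a high-quality dataset is only usable when its embedding model is still among the tenant's active TEXT_EMBEDDING models. A sketch of that check, assuming the dataset payload exposes embedding_model and embedding_model_provider fields (those field names are not shown in this hunk and are an assumption):

# Sketch of the availability check built on the "model:provider" keys above.
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager


def active_embedding_keys(tenant_id: str) -> set[str]:
    # same calls as the hunk: active TEXT_EMBEDDING models for the tenant
    configurations = ProviderManager().get_configurations(tenant_id=tenant_id)
    embedding_models = configurations.get_models(
        model_type=ModelType.TEXT_EMBEDDING,
        only_active=True
    )
    return {f"{m.model}:{m.provider.provider}" for m in embedding_models}


def is_dataset_embedding_available(item: dict, keys: set[str]) -> bool:
    if item['indexing_technique'] != 'high_quality':
        return True
    # assumed field names on the marshalled dataset payload
    return f"{item['embedding_model']}:{item['embedding_model_provider']}" in keys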
@@ -13,7 +13,7 @@ from controllers.service_api.dataset.error import ArchivedDocumentImmutableError,
     NoFileUploadedError, TooManyFilesError
 from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
 from libs.login import current_user
-from core.model_providers.error import ProviderTokenNotInitError
+from core.errors.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from fields.document_fields import document_fields, document_status_fields
 from models.dataset import Dataset, Document, DocumentSegment
@@ -4,8 +4,9 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from fields.segment_fields import segment_fields
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
@@ -35,10 +36,12 @@ class SegmentApi(DatasetApiResource):
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
@@ -77,10 +80,12 @@ class SegmentApi(DatasetApiResource):
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
@@ -167,10 +172,12 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
|
||||
@@ -11,8 +11,7 @@ from libs.login import _get_user
 from extensions.ext_database import db
 from models.account import Tenant, TenantAccountJoin, Account
 from models.model import ApiToken, App
-from services.billing_service import BillingService
-
+from services.feature_service import FeatureService
 
 def validate_app_token(view=None):
     def decorator(view):
@@ -46,19 +45,19 @@ def cloud_edition_billing_resource_check(resource: str,
                                          error_msg: str = "You have reached the limit of your subscription."):
     def interceptor(view):
         def decorated(*args, **kwargs):
-            if current_app.config['EDITION'] == 'CLOUD':
-                api_token = validate_and_get_api_token(api_token_type)
-                billing_info = BillingService.get_info(api_token.tenant_id)
+            api_token = validate_and_get_api_token(api_token_type)
+            features = FeatureService.get_features(api_token.tenant_id)
 
-                members = billing_info['members']
-                apps = billing_info['apps']
-                vector_space = billing_info['vector_space']
+            if features.billing.enabled:
+                members = features.members
+                apps = features.apps
+                vector_space = features.vector_space
 
-                if resource == 'members' and 0 < members['limit'] <= members['size']:
+                if resource == 'members' and 0 < members.limit <= members.size:
                     raise Unauthorized(error_msg)
-                elif resource == 'apps' and 0 < apps['limit'] <= apps['size']:
+                elif resource == 'apps' and 0 < apps.limit <= apps.size:
                     raise Unauthorized(error_msg)
-                elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
+                elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
                     raise Unauthorized(error_msg)
                 else:
                     return view(*args, **kwargs)
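Both wrappers (console and service-api) now consume the same feature payload through attribute access instead of dict lookups. A hedged sketch of the contract they rely on; the real models live in services.feature_service and may differ in detail:

# Illustrative shape of the FeatureService payload used by the quota checks.
from dataclasses import dataclass, field


@dataclass
class LimitationModel:
    size: int = 0
    limit: int = 0  # 0 is treated as unlimited: checks use `0 < limit <= size`


@dataclass
class BillingModel:
    enabled: bool = False


@dataclass
class FeatureModel:
    billing: BillingModel = field(default_factory=BillingModel)
    members: LimitationModel = field(default_factory=LimitationModel)
    apps: LimitationModel = field(default_factory=LimitationModel)
    vector_space: LimitationModel = field(default_factory=LimitationModel)
    annotation_quota_limit: LimitationModel = field(default_factory=LimitationModel)
    can_replace_logo: bool = False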
@@ -30,6 +30,7 @@ class AppParameterApi(WebApiResource):
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'annotation_reply': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'sensitive_word_avoidance': fields.Raw,
|
||||
@@ -48,6 +49,7 @@ class AppParameterApi(WebApiResource):
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'annotation_reply': app_model_config.annotation_reply_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list,
|
||||
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
|
||||
|
||||
@@ -10,8 +10,8 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
@@ -51,8 +51,7 @@ class AudioApi(WebApiResource):
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
||||
@@ -13,9 +13,10 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
|
||||
ProviderNotInitializeError, NotChatAppError, NotCompletionAppError, CompletionRequestError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
|
||||
@@ -44,7 +45,7 @@ class CompletionApi(WebApiResource):
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
args=args,
|
||||
from_source='api',
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
@@ -62,8 +63,7 @@ class CompletionApi(WebApiResource):
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
@@ -77,7 +77,7 @@ class CompletionStopApi(WebApiResource):
|
||||
if app_model.mode != 'completion':
|
||||
raise NotCompletionAppError()
|
||||
|
||||
PubHandler.stop(end_user, task_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
@@ -105,7 +105,7 @@ class ChatApi(WebApiResource):
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
args=args,
|
||||
from_source='api',
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
@@ -123,8 +123,7 @@ class ChatApi(WebApiResource):
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
@@ -138,7 +137,7 @@ class ChatStopApi(WebApiResource):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
PubHandler.stop(end_user, task_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
@@ -164,8 +163,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||
except ValueError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||
|
||||
@@ -14,8 +14,9 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
|
||||
AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
@@ -117,7 +118,14 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
try:
|
||||
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
|
||||
response = CompletionService.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
return compact_response(response)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
@@ -129,8 +137,7 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
@@ -157,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||
except ValueError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||
@@ -195,8 +201,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
|
||||
from flask_restful import fields, marshal_with
|
||||
from flask import current_app
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from models.model import Site
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
class AppSiteApi(WebApiResource):
|
||||
@@ -39,6 +43,8 @@ class AppSiteApi(WebApiResource):
|
||||
'site': fields.Nested(site_fields),
|
||||
'model_config': fields.Nested(model_config_fields, allow_null=True),
|
||||
'plan': fields.String,
|
||||
'can_replace_logo': fields.Boolean,
|
||||
'custom_config': fields.Raw(attribute='custom_config'),
|
||||
}
|
||||
|
||||
@marshal_with(app_fields)
|
||||
@@ -50,7 +56,9 @@ class AppSiteApi(WebApiResource):
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id)
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
|
||||
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
|
||||
|
||||
|
||||
api.add_resource(AppSiteApi, '/site')
|
||||
@@ -59,7 +67,7 @@ api.add_resource(AppSiteApi, '/site')
|
||||
class AppSiteInfo:
|
||||
"""Class to store site information."""
|
||||
|
||||
def __init__(self, tenant, app, site, end_user):
|
||||
def __init__(self, tenant, app, site, end_user, can_replace_logo):
|
||||
"""Initialize AppSiteInfo instance."""
|
||||
self.app_id = app.id
|
||||
self.end_user_id = end_user
|
||||
@@ -67,6 +75,16 @@ class AppSiteInfo:
|
||||
self.site = site
|
||||
self.model_config = None
|
||||
self.plan = tenant.plan
|
||||
self.can_replace_logo = can_replace_logo
|
||||
|
||||
if can_replace_logo:
|
||||
base_url = current_app.config.get('FILES_URL')
|
||||
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
|
||||
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
|
||||
self.custom_config = {
|
||||
'remove_webapp_brand': remove_webapp_brand,
|
||||
'replace_webapp_logo': replace_webapp_logo,
|
||||
}
|
||||
|
||||
if app.enable_site and site.prompt_public:
|
||||
app_model_config = app.app_model_config
|
||||
|
||||
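With can_replace_logo enabled, the new branch above builds a small custom_config dict from FILES_URL and the tenant's custom_config_dict. Illustrative result only; the base URL and tenant id below are made up.

base_url = 'https://files.example.com'                  # assumed FILES_URL value
tenant_id = '8f2d3c1a-0000-0000-0000-000000000000'       # made-up tenant id

custom_config = {
    'remove_webapp_brand': True,
    'replace_webapp_logo': f'{base_url}/files/workspaces/{tenant_id}/webapp-logo',
}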
api/core/agent/agent/agent_llm_callback.py (new file, 101 lines)
@@ -0,0 +1,101 @@
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentLLMCallback(Callback):
|
||||
|
||||
def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None:
|
||||
self.agent_callback = agent_callback
|
||||
|
||||
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Before invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
self.agent_callback.on_llm_before_invoke(
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None):
|
||||
"""
|
||||
On new chunk callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param chunk: chunk
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
After invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param result: result
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
self.agent_callback.on_llm_after_invoke(
|
||||
result=result
|
||||
)
|
||||
|
||||
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Invoke error callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param ex: exception
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
"""
|
||||
self.agent_callback.on_llm_error(
|
||||
error=ex
|
||||
)
|
||||
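agent_llm_callback.py is a thin adapter: each model-runtime hook is forwarded to the existing AgentLoopGatherCallbackHandler. A wiring sketch, assuming the surrounding objects (model instance, message, queue manager, message chain) are already built as in the agent app runner later in this compare; the helper name is hypothetical.

from typing import Sequence

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler


def invoke_with_agent_callback(model_instance, model_config, message, queue_manager,
                               message_chain, prompt_messages: Sequence):
    """Hypothetical helper showing how the new callback is meant to be wired in."""
    agent_callback = AgentLoopGatherCallbackHandler(
        model_config=model_config,       # assumed to be a ModelConfigEntity
        message=message,
        queue_manager=queue_manager,
        message_chain=message_chain,
    )
    agent_llm_callback = AgentLLMCallback(agent_callback=agent_callback)

    # invoke_llm accepts a `callbacks` list (see the plan() hunk later in this compare);
    # on_before_invoke / on_after_invoke / on_invoke_error fire around the call.
    return model_instance.invoke_llm(
        prompt_messages=prompt_messages,
        model_parameters={'temperature': 0.2, 'top_p': 0.3, 'max_tokens': 1500},
        stream=False,
        callbacks=[agent_llm_callback],
    )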
@@ -1,28 +1,49 @@
|
||||
from typing import List
|
||||
from typing import List, cast
|
||||
|
||||
from langchain.schema import BaseMessage
|
||||
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class CalcTokenMixin:
|
||||
|
||||
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
return model_instance.get_num_tokens(to_prompt_messages(messages))
|
||||
|
||||
def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
|
||||
"""
|
||||
Get the rest tokens available for the model after excluding message tokens and completion max tokens
|
||||
|
||||
:param llm:
|
||||
:param model_config:
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
llm_max_tokens = model_instance.model_rules.max_tokens.max
|
||||
completion_max_tokens = model_instance.model_kwargs.max_tokens
|
||||
used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
|
||||
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
max_tokens = (model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)) or 0
|
||||
|
||||
if model_context_tokens is None:
|
||||
return 0
|
||||
|
||||
if max_tokens is None:
|
||||
max_tokens = 0
|
||||
|
||||
prompt_tokens = model_type_instance.get_num_tokens(
|
||||
model_config.model,
|
||||
model_config.credentials,
|
||||
messages
|
||||
)
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
|
||||
|
||||
return rest_tokens
|
||||
|
||||
|
||||
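The rewritten get_message_rest_tokens reads the context size from the model schema, resolves max_tokens from the parameter rules, and subtracts the prompt token count. A worked example with illustrative numbers, not taken from this compare:

model_context_tokens = 4096   # ModelPropertyKey.CONTEXT_SIZE from the model schema
max_tokens = 1500             # resolved from the 'max_tokens' parameter rule
prompt_tokens = 2700          # get_num_tokens(...) over the prompt messages

rest_tokens = model_context_tokens - max_tokens - prompt_tokens
print(rest_tokens)            # -104 -> negative, so the agents below fall back to summarization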
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
@@ -6,13 +5,14 @@ from langchain.agents.openai_functions_agent.base import _format_intermediate_st
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.model_manager import ModelInstance
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
"""
|
||||
A Multi Dataset Retrieve Agent driven by Router.
|
||||
"""
|
||||
model_instance: BaseLLM
|
||||
model_config: ModelConfigEntity
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -81,8 +81,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
agent_decision.return_values['output'] = ''
|
||||
return agent_decision
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
raise e
|
||||
|
||||
def real_plan(
|
||||
self,
|
||||
@@ -106,16 +105,39 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
||||
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=self.model_config.provider_model_bundle,
|
||||
model=self.model_config.model,
|
||||
)
|
||||
|
||||
tools = []
|
||||
for function in self.functions:
|
||||
tool = PromptMessageTool(
|
||||
**function
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
model_parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.content,
|
||||
content=result.message.content or "",
|
||||
additional_kwargs={
|
||||
'function_call': result.function_call
|
||||
'function_call': {
|
||||
'id': result.message.tool_calls[0].id,
|
||||
**result.message.tool_calls[0].function.dict()
|
||||
} if result.message.tool_calls else None
|
||||
}
|
||||
)
|
||||
|
||||
@@ -133,7 +155,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_instance: BaseLLM,
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
@@ -147,7 +169,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_instance=model_instance,
|
||||
model_config=model_config,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
|
||||
@@ -13,18 +13,23 @@ from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage,
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.agent.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.model_manager import ModelInstance
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
|
||||
|
||||
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_model_instance: BaseLLM = None
|
||||
model_instance: BaseLLM
|
||||
summary_model_config: ModelConfigEntity = None
|
||||
model_config: ModelConfigEntity
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -38,13 +43,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_instance: BaseLLM,
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
prompt = cls.create_prompt(
|
||||
@@ -52,11 +58,12 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_instance=model_instance,
|
||||
model_config=model_config,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
agent_llm_callback=agent_llm_callback,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -67,28 +74,49 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = self.model_instance.model_kwargs.max_tokens
|
||||
self.model_instance.model_kwargs.max_tokens = 40
|
||||
original_max_tokens = 0
|
||||
for parameter_rule in self.model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
|
||||
or self.model_config.parameters.get(parameter_rule.use_template)) or 0
|
||||
|
||||
self.model_config.parameters['max_tokens'] = 40
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
try:
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
callbacks=None
|
||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=self.model_config.provider_model_bundle,
|
||||
model=self.model_config.model,
|
||||
)
|
||||
|
||||
tools = []
|
||||
for function in self.functions:
|
||||
tool = PromptMessageTool(
|
||||
**function
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
model_parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
raise e
|
||||
|
||||
function_call = result.function_call
|
||||
self.model_config.parameters['max_tokens'] = original_max_tokens
|
||||
|
||||
self.model_instance.model_kwargs.max_tokens = original_max_tokens
|
||||
|
||||
return True if function_call else False
|
||||
return True if result.message.tool_calls else False
|
||||
|
||||
def plan(
|
||||
self,
|
||||
@@ -113,22 +141,46 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
|
||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
||||
|
||||
# summarize messages if rest_tokens < 0
|
||||
try:
|
||||
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
|
||||
prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=self.model_config.provider_model_bundle,
|
||||
model=self.model_config.model,
|
||||
)
|
||||
|
||||
tools = []
|
||||
for function in self.functions:
|
||||
tool = PromptMessageTool(
|
||||
**function
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
stream=False,
|
||||
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
|
||||
model_parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.content,
|
||||
content=result.message.content or "",
|
||||
additional_kwargs={
|
||||
'function_call': result.function_call
|
||||
'function_call': {
|
||||
'id': result.message.tool_calls[0].id,
|
||||
**result.message.tool_calls[0].function.dict()
|
||||
} if result.message.tool_calls else None
|
||||
}
|
||||
)
|
||||
agent_decision = _parse_ai_message(ai_message)
|
||||
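Both real_plan implementations above now rebuild LangChain's function_call kwargs from the first tool call returned by invoke_llm. The shape it produces looks roughly like this; the field values are invented, and .function.dict() is assumed to contribute name and arguments.

tool_call_id = "call_abc123"
function_payload = {"name": "dataset-a1b2", "arguments": '{"query": "pricing tiers"}'}

additional_kwargs = {
    "function_call": {"id": tool_call_id, **function_payload},
}
# -> {'function_call': {'id': 'call_abc123', 'name': 'dataset-a1b2',
#                       'arguments': '{"query": "pricing tiers"}'}}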
@@ -158,9 +210,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
except ValueError:
|
||||
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
|
||||
def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
|
||||
rest_tokens = self.get_message_rest_tokens(
|
||||
self.model_config,
|
||||
messages,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
|
||||
if rest_tokens >= 0:
|
||||
return messages
|
||||
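summarize_messages_if_needed now works on PromptMessage lists and, as before, subtracts a 20-token margin to absorb the acknowledged inaccuracy of the rest-token estimate. A minimal sketch of that guard, with the margin as an assumed parameter:

def needs_summary(rest_tokens: int, safety_margin: int = 20) -> bool:
    # summarize when the estimate, minus the margin, would overflow the context window
    return rest_tokens - safety_margin < 0

assert needs_summary(10) is True     # nominally fits, but inside the safety margin
assert needs_summary(200) is False   # comfortably within the context window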
@@ -210,19 +267,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
ai_prefix="AI",
|
||||
)
|
||||
|
||||
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
|
||||
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
if model_instance.model_provider.provider_name == 'azure_openai':
|
||||
model = model_instance.base_model_name
|
||||
if model_config.provider == 'azure_openai':
|
||||
model = model_config.model
|
||||
model = model.replace("gpt-35", "gpt-3.5")
|
||||
else:
|
||||
model = model_instance.base_model_name
|
||||
model = model_config.credentials.get("base_model_name")
|
||||
|
||||
tiktoken_ = _import_tiktoken()
|
||||
try:
|
||||
|
||||
@@ -1,158 +0,0 @@
|
||||
import json
|
||||
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
"""
|
||||
A Multi Dataset Retrieve Agent driven by Router.
|
||||
"""
|
||||
model_instance: BaseLLM
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
return values
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
return True
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
if len(self.tools) == 0:
|
||||
return AgentFinish(return_values={"output": ''}, log='')
|
||||
elif len(self.tools) == 1:
|
||||
tool = next(iter(self.tools))
|
||||
tool = cast(DatasetRetrieverTool, tool)
|
||||
rst = tool.run(tool_input={'query': kwargs['input']})
|
||||
# output = ''
|
||||
# rst_json = json.loads(rst)
|
||||
# for item in rst_json:
|
||||
# output += f'{item["content"]}\n'
|
||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||
|
||||
if intermediate_steps:
|
||||
_, observation = intermediate_steps[-1]
|
||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||
|
||||
try:
|
||||
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
|
||||
if isinstance(agent_decision, AgentAction):
|
||||
tool_inputs = agent_decision.tool_input
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
else:
|
||||
agent_decision.return_values['output'] = ''
|
||||
return agent_decision
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
def real_plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.content,
|
||||
additional_kwargs={
|
||||
'function_call': result.function_call
|
||||
}
|
||||
)
|
||||
|
||||
agent_decision = _parse_ai_message(ai_message)
|
||||
return agent_decision
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
prompt = cls.create_prompt(
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_instance=model_instance,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -12,9 +12,7 @@ from langchain.tools import BaseTool
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.entity.model_params import ModelMode
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||
@@ -69,10 +67,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
return True
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
@@ -101,8 +99,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
try:
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
except Exception as e:
|
||||
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
raise e
|
||||
|
||||
try:
|
||||
agent_decision = self.output_parser.parse(full_output)
|
||||
@@ -119,6 +116,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
except OutputParserException:
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
@@ -182,7 +180,7 @@ Thought: {agent_scratchpad}
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
agent_scratchpad = ""
|
||||
for action, observation in intermediate_steps:
|
||||
@@ -193,7 +191,7 @@ Thought: {agent_scratchpad}
|
||||
raise ValueError("agent_scratchpad should be of type string.")
|
||||
if agent_scratchpad:
|
||||
llm_chain = cast(LLMChain, self.llm_chain)
|
||||
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
|
||||
if llm_chain.model_config.mode == "chat":
|
||||
return (
|
||||
f"This was your previous work "
|
||||
f"(but I haven't seen any of it! I only see what "
|
||||
@@ -207,7 +205,7 @@ Thought: {agent_scratchpad}
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_instance: BaseLLM,
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
@@ -221,7 +219,7 @@ Thought: {agent_scratchpad}
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
if model_instance.model_mode == ModelMode.CHAT:
|
||||
if model_config.mode == "chat":
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
@@ -238,10 +236,16 @@ Thought: {agent_scratchpad}
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables
|
||||
)
|
||||
|
||||
llm_chain = LLMChain(
|
||||
model_instance=model_instance,
|
||||
model_config=model_config,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
_output_parser = output_parser
|
||||
|
||||
@@ -13,10 +13,11 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage,
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.agent.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.entity.model_params import ModelMode
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||
@@ -54,7 +55,7 @@ Action:
|
||||
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_model_instance: BaseLLM = None
|
||||
summary_model_config: ModelConfigEntity = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
@@ -96,15 +97,16 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
if prompts:
|
||||
messages = prompts[0].to_messages()
|
||||
|
||||
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
|
||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
||||
|
||||
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
|
||||
if rest_tokens < 0:
|
||||
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
|
||||
|
||||
try:
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
except Exception as e:
|
||||
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
raise e
|
||||
|
||||
try:
|
||||
agent_decision = self.output_parser.parse(full_output)
|
||||
@@ -119,7 +121,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||
if len(intermediate_steps) >= 2 and self.summary_model_instance:
|
||||
if len(intermediate_steps) >= 2 and self.summary_model_config:
|
||||
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||
should_summary_messages = [AIMessage(content=observation)
|
||||
for _, observation in should_summary_intermediate_steps]
|
||||
@@ -153,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
ai_prefix="AI",
|
||||
)
|
||||
|
||||
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
|
||||
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
@classmethod
|
||||
@@ -229,7 +231,7 @@ Thought: {agent_scratchpad}
|
||||
raise ValueError("agent_scratchpad should be of type string.")
|
||||
if agent_scratchpad:
|
||||
llm_chain = cast(LLMChain, self.llm_chain)
|
||||
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
|
||||
if llm_chain.model_config.mode == "chat":
|
||||
return (
|
||||
f"This was your previous work "
|
||||
f"(but I haven't seen any of it! I only see what "
|
||||
@@ -243,7 +245,7 @@ Thought: {agent_scratchpad}
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_instance: BaseLLM,
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
@@ -253,11 +255,12 @@ Thought: {agent_scratchpad}
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
if model_instance.model_mode == ModelMode.CHAT:
|
||||
if model_config.mode == "chat":
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
@@ -275,9 +278,15 @@ Thought: {agent_scratchpad}
|
||||
input_variables=input_variables,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
model_instance=model_instance,
|
||||
model_config=model_config,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
agent_llm_callback=agent_llm_callback,
|
||||
parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
_output_parser = output_parser
|
||||
|
||||
@@ -4,10 +4,10 @@ from typing import Union, Optional
|
||||
|
||||
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from core.agent.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
||||
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
|
||||
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||
@@ -15,9 +15,11 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
|
||||
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
|
||||
from langchain.agents import AgentExecutor as LCAgentExecutor
|
||||
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.entities.message_entities import prompt_messages_to_lc_messages
|
||||
from core.helper import moderation
|
||||
from core.model_providers.error import LLMError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
@@ -31,14 +33,15 @@ class PlanningStrategy(str, enum.Enum):
|
||||
|
||||
class AgentConfiguration(BaseModel):
|
||||
strategy: PlanningStrategy
|
||||
model_instance: BaseLLM
|
||||
model_config: ModelConfigEntity
|
||||
tools: list[BaseTool]
|
||||
summary_model_instance: BaseLLM = None
|
||||
memory: Optional[BaseChatMemory] = None
|
||||
summary_model_config: Optional[ModelConfigEntity] = None
|
||||
memory: Optional[TokenBufferMemory] = None
|
||||
callbacks: Callbacks = None
|
||||
max_iterations: int = 6
|
||||
max_execution_time: Optional[float] = None
|
||||
early_stopping_method: str = "generate"
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None
|
||||
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
|
||||
|
||||
class Config:
|
||||
@@ -62,34 +65,42 @@ class AgentExecutor:
|
||||
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
|
||||
if self.configuration.strategy == PlanningStrategy.REACT:
|
||||
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
model_config=self.configuration.model_config,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
summary_model_instance=self.configuration.summary_model_instance
|
||||
if self.configuration.summary_model_instance else None,
|
||||
summary_model_config=self.configuration.summary_model_config
|
||||
if self.configuration.summary_model_config else None,
|
||||
agent_llm_callback=self.configuration.agent_llm_callback,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
model_config=self.configuration.model_config,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_model_instance=self.configuration.summary_model_instance
|
||||
if self.configuration.summary_model_instance else None,
|
||||
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
|
||||
if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_model_config=self.configuration.summary_model_config
|
||||
if self.configuration.summary_model_config else None,
|
||||
agent_llm_callback=self.configuration.agent_llm_callback,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.ROUTER:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
|
||||
self.configuration.tools = [t for t in self.configuration.tools
|
||||
if isinstance(t, DatasetRetrieverTool)
|
||||
or isinstance(t, DatasetMultiRetrieverTool)]
|
||||
agent = MultiDatasetRouterAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
model_config=self.configuration.model_config,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
|
||||
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
|
||||
if self.configuration.memory else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
|
||||
self.configuration.tools = [t for t in self.configuration.tools
|
||||
if isinstance(t, DatasetRetrieverTool)
|
||||
or isinstance(t, DatasetMultiRetrieverTool)]
|
||||
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
model_config=self.configuration.model_config,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
verbose=True
|
||||
@@ -104,11 +115,11 @@ class AgentExecutor:
|
||||
|
||||
def run(self, query: str) -> AgentExecuteResult:
|
||||
moderation_result = moderation.check_moderation(
|
||||
self.configuration.model_instance.model_provider,
|
||||
self.configuration.model_config,
|
||||
query
|
||||
)
|
||||
|
||||
if not moderation_result:
|
||||
if moderation_result:
|
||||
return AgentExecuteResult(
|
||||
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
|
||||
strategy=self.configuration.strategy,
|
||||
@@ -118,7 +129,6 @@ class AgentExecutor:
|
||||
agent_executor = LCAgentExecutor.from_agent_and_tools(
|
||||
agent=self.agent,
|
||||
tools=self.configuration.tools,
|
||||
memory=self.configuration.memory,
|
||||
max_iterations=self.configuration.max_iterations,
|
||||
max_execution_time=self.configuration.max_execution_time,
|
||||
early_stopping_method=self.configuration.early_stopping_method,
|
||||
@@ -126,8 +136,8 @@ class AgentExecutor:
|
||||
)
|
||||
|
||||
try:
|
||||
output = agent_executor.run(query)
|
||||
except LLMError as ex:
|
||||
output = agent_executor.run(input=query)
|
||||
except InvokeError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logging.exception("agent_executor run failed")
|
||||
|
||||
api/core/app_runner/agent_app_runner.py (new file, 251 lines)
@@ -0,0 +1,251 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from core.agent.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.app_runner.app_runner import AppRunner
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.entities.application_entities import ApplicationGenerateEntity, PromptTemplateEntity, ModelConfigEntity
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.features.agent_runner import AgentRunnerFeature
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from models.model import Conversation, Message, App, MessageChain, MessageAgentThought

logger = logging.getLogger(__name__)


class AgentApplicationRunner(AppRunner):
    """
    Agent Application Runner
    """

    def run(self, application_generate_entity: ApplicationGenerateEntity,
            queue_manager: ApplicationQueueManager,
            conversation: Conversation,
            message: Message) -> None:
        """
        Run agent application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError(f"App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            memory = TokenBufferMemory(
                conversation=conversation,
                model_instance=model_instance
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        # memory(optional)
        prompt_messages, stop = self.originze_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=None,
            memory=memory
        )

        # Create MessageChain
        message_chain = self._init_message_chain(
            message=message,
            query=query
        )

        # add agent callback to record agent thoughts
        agent_callback = AgentLoopGatherCallbackHandler(
            model_config=app_orchestration_config.model_config,
            message=message,
            queue_manager=queue_manager,
            message_chain=message_chain
        )

        # init LLM Callback
        agent_llm_callback = AgentLLMCallback(
            agent_callback=agent_callback
        )

        agent_runner = AgentRunnerFeature(
            tenant_id=application_generate_entity.tenant_id,
            app_orchestration_config=app_orchestration_config,
            model_config=app_orchestration_config.model_config,
            config=app_orchestration_config.agent,
            queue_manager=queue_manager,
            message=message,
            user_id=application_generate_entity.user_id,
            agent_llm_callback=agent_llm_callback,
            callback=agent_callback,
            memory=memory
        )

        # agent run
        result = agent_runner.run(
            query=query,
            invoke_from=application_generate_entity.invoke_from
        )

        if result:
            self._save_message_chain(
                message_chain=message_chain,
                output_text=result
            )

        if (result
                and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE
                and app_orchestration_config.prompt_template.simple_prompt_template
        ):
            # Direct output if agent result exists and has pre prompt
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=app_orchestration_config,
                prompt_messages=prompt_messages,
                stream=application_generate_entity.stream,
                text=result,
                usage=self._get_usage_of_all_agent_thoughts(
                    model_config=app_orchestration_config.model_config,
                    message=message
                )
            )
        else:
            # As normal LLM run, agent result as context
            context = result

            # reorganize all inputs and template to prompt messages
            # Include: prompt template, inputs, query(optional), files(optional)
            # memory(optional), external data, dataset context(optional)
            prompt_messages, stop = self.originze_prompt_messages(
                app_record=app_record,
                model_config=app_orchestration_config.model_config,
                prompt_template_entity=app_orchestration_config.prompt_template,
                inputs=inputs,
                files=files,
                query=query,
                context=context,
                memory=memory
            )

            # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
            self.recale_llm_max_tokens(
                model_config=app_orchestration_config.model_config,
                prompt_messages=prompt_messages
            )

            # Invoke model
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            invoke_result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                stop=stop,
                stream=application_generate_entity.stream,
                user=application_generate_entity.user_id,
            )

            # handle invoke result
            self._handle_invoke_result(
                invoke_result=invoke_result,
                queue_manager=queue_manager,
                stream=application_generate_entity.stream
            )

    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
        """
        Init MessageChain
        :param message: message
        :param query: query
        :return:
        """
        message_chain = MessageChain(
            message_id=message.id,
            type="AgentExecutor",
            input=json.dumps({
                "input": query
            })
        )

        db.session.add(message_chain)
        db.session.commit()

        return message_chain

    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
        """
        Save MessageChain
        :param message_chain: message chain
        :param output_text: output text
        :return:
        """
        message_chain.output = json.dumps({
            "output": output_text
        })
        db.session.commit()

    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                         message: Message) -> LLMUsage:
        """
        Get usage of all agent thoughts
        :param model_config: model config
        :param message: message
        :return:
        """
        agent_thoughts = (db.session.query(MessageAgentThought)
                          .filter(MessageAgentThought.message_id == message.id).all())

        all_message_tokens = 0
        all_answer_tokens = 0
        for agent_thought in agent_thoughts:
            all_message_tokens += agent_thought.message_tokens
            all_answer_tokens += agent_thought.answer_tokens

        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        return model_type_instance._calc_response_usage(
            model_config.model,
            model_config.credentials,
            all_message_tokens,
            all_answer_tokens
        )
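For orientation, here is a minimal sketch of how this runner is selected and invoked. It mirrors the `_generate_worker` logic in `api/core/application_manager.py` further down in this diff; the entity, queue manager, conversation and message objects are assumed to be already constructed by `ApplicationManager`.

# Minimal dispatch sketch (objects assumed already built by ApplicationManager):
if application_generate_entity.app_orchestration_config_entity.agent:
    runner = AgentApplicationRunner()   # agent-orchestrated app
else:
    runner = BasicApplicationRunner()   # plain prompt/completion app

runner.run(
    application_generate_entity=application_generate_entity,
    queue_manager=queue_manager,
    conversation=conversation,
    message=message
)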
api/core/app_runner/app_runner.py (new file, 267 lines)
@@ -0,0 +1,267 @@
import time
from typing import cast, Optional, List, Tuple, Generator, Union

from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, AssistantPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform
from models.model import App


class AppRunner:
    def get_pre_calculate_rest_tokens(self, app_record: App,
                                      model_config: ModelConfigEntity,
                                      prompt_template_entity: PromptTemplateEntity,
                                      inputs: dict[str, str],
                                      files: list[FileObj],
                                      query: Optional[str] = None) -> int:
        """
        Get pre calculate rest tokens
        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query
        :return:
        """
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        # get prompt messages without memory and context
        prompt_messages, stop = self.originze_prompt_messages(
            app_record=app_record,
            model_config=model_config,
            prompt_template_entity=prompt_template_entity,
            inputs=inputs,
            files=files,
            query=query
        )

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages
        )

        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
                                        "or shrink the max token, or switch to a llm with a larger token limit size.")

        return rest_tokens

    def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
                              prompt_messages: List[PromptMessage]):
        # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages
        )

        if prompt_tokens + max_tokens > model_context_tokens:
            max_tokens = max(model_context_tokens - prompt_tokens, 16)

            for parameter_rule in model_config.model_schema.parameter_rules:
                if (parameter_rule.name == 'max_tokens'
                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                    model_config.parameters[parameter_rule.name] = max_tokens

    def originze_prompt_messages(self, app_record: App,
                                 model_config: ModelConfigEntity,
                                 prompt_template_entity: PromptTemplateEntity,
                                 inputs: dict[str, str],
                                 files: list[FileObj],
                                 query: Optional[str] = None,
                                 context: Optional[str] = None,
                                 memory: Optional[TokenBufferMemory] = None) \
            -> Tuple[List[PromptMessage], Optional[List[str]]]:
        """
        Organize prompt messages
        :param context:
        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query
        :param memory: memory
        :return:
        """
        prompt_transform = PromptTransform()

        # get prompt without memory and context
        if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
            prompt_messages, stop = prompt_transform.get_prompt(
                app_mode=app_record.mode,
                prompt_template_entity=prompt_template_entity,
                inputs=inputs,
                query=query if query else '',
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=app_record.mode,
                prompt_template_entity=prompt_template_entity,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
            stop = model_config.stop

        return prompt_messages, stop

    def direct_output(self, queue_manager: ApplicationQueueManager,
                      app_orchestration_config: AppOrchestrationConfigEntity,
                      prompt_messages: list,
                      text: str,
                      stream: bool,
                      usage: Optional[LLMUsage] = None) -> None:
        """
        Direct output
        :param queue_manager: application queue manager
        :param app_orchestration_config: app orchestration config
        :param prompt_messages: prompt messages
        :param text: text
        :param stream: stream
        :param usage: usage
        :return:
        """
        if stream:
            index = 0
            for token in text:
                queue_manager.publish_chunk_message(LLMResultChunk(
                    model=app_orchestration_config.model_config.model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=AssistantPromptMessage(content=token)
                    )
                ))
                index += 1
                time.sleep(0.01)

        queue_manager.publish_message_end(
            llm_result=LLMResult(
                model=app_orchestration_config.model_config.model,
                prompt_messages=prompt_messages,
                message=AssistantPromptMessage(content=text),
                usage=usage if usage else LLMUsage.empty_usage()
            )
        )

    def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
                              queue_manager: ApplicationQueueManager,
                              stream: bool) -> None:
        """
        Handle invoke result
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :param stream: stream
        :return:
        """
        if not stream:
            self._handle_invoke_result_direct(
                invoke_result=invoke_result,
                queue_manager=queue_manager
            )
        else:
            self._handle_invoke_result_stream(
                invoke_result=invoke_result,
                queue_manager=queue_manager
            )

    def _handle_invoke_result_direct(self, invoke_result: LLMResult,
                                     queue_manager: ApplicationQueueManager) -> None:
        """
        Handle invoke result direct
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :return:
        """
        queue_manager.publish_message_end(
            llm_result=invoke_result
        )

    def _handle_invoke_result_stream(self, invoke_result: Generator,
                                     queue_manager: ApplicationQueueManager) -> None:
        """
        Handle invoke result
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :return:
        """
        model = None
        prompt_messages = []
        text = ''
        usage = None
        for result in invoke_result:
            queue_manager.publish_chunk_message(result)

            text += result.delta.message.content

            if not model:
                model = result.model

            if not prompt_messages:
                prompt_messages = result.prompt_messages

            if not usage and result.delta.usage:
                usage = result.delta.usage

        llm_result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(content=text),
            usage=usage
        )

        queue_manager.publish_message_end(
            llm_result=llm_result
        )
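A small worked example of the token budgeting above, with illustrative numbers that are assumptions rather than values from the diff: `get_pre_calculate_rest_tokens` subtracts the configured `max_tokens` and the measured prompt tokens from the model's context size, and `recale_llm_max_tokens` later shrinks `max_tokens` (to at least 16) if the final prompt grew too large for the context window.

# Hypothetical numbers for illustration only:
model_context_tokens = 4096   # ModelPropertyKey.CONTEXT_SIZE of the chosen model
max_tokens = 512              # 'max_tokens' parameter from the app config
prompt_tokens = 3200          # measured by get_num_tokens()

rest_tokens = model_context_tokens - max_tokens - prompt_tokens   # 384 -> no InvokeBadRequestError
# If the final prompt later grows to 3800 tokens, recale_llm_max_tokens would set
# max_tokens = max(4096 - 3800, 16) == 296 so the request still fits the context window.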
api/core/app_runner/basic_app_runner.py (new file, 363 lines)
@@ -0,0 +1,363 @@
import logging
from typing import Tuple, Optional

from core.app_runner.app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
    AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
from core.application_queue_manager import ApplicationQueueManager
from core.features.annotation_reply import AnnotationReplyFeature
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.external_data_fetch import ExternalDataFetchFeature
from core.features.hosting_moderation import HostingModerationFeature
from core.features.moderation import ModerationFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.moderation.base import ModerationException
from core.prompt.prompt_transform import AppMode
from extensions.ext_database import db
from models.model import Conversation, Message, App, MessageAnnotation

logger = logging.getLogger(__name__)


class BasicApplicationRunner(AppRunner):
    """
    Basic Application Runner
    """

    def run(self, application_generate_entity: ApplicationGenerateEntity,
            queue_manager: ApplicationQueueManager,
            conversation: Conversation,
            message: Message) -> None:
        """
        Run application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError(f"App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            memory = TokenBufferMemory(
                conversation=conversation,
                model_instance=model_instance
            )

        # organize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        # memory(optional)
        prompt_messages, stop = self.originze_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory
        )

        # moderation
        try:
            # process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=application_generate_entity.tenant_id,
                app_orchestration_config_entity=app_orchestration_config,
                inputs=inputs,
                query=query,
            )
        except ModerationException as e:
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=app_orchestration_config,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream
            )
            return

        if query:
            # annotation reply
            annotation_reply = self.query_app_annotations_to_reply(
                app_record=app_record,
                message=message,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from
            )

            if annotation_reply:
                queue_manager.publish_annotation_reply(
                    message_annotation_id=annotation_reply.id
                )
                self.direct_output(
                    queue_manager=queue_manager,
                    app_orchestration_config=app_orchestration_config,
                    prompt_messages=prompt_messages,
                    text=annotation_reply.content,
                    stream=application_generate_entity.stream
                )
                return

        # fill in variable inputs from external data tools if exists
        external_data_tools = app_orchestration_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query
            )

        # get context from datasets
        context = None
        if app_orchestration_config.dataset:
            context = self.retrieve_dataset_context(
                tenant_id=app_record.tenant_id,
                app_record=app_record,
                queue_manager=queue_manager,
                model_config=app_orchestration_config.model_config,
                show_retrieve_source=app_orchestration_config.show_retrieve_source,
                dataset_config=app_orchestration_config.dataset,
                message=message,
                inputs=inputs,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
                memory=memory
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        # memory(optional), external data, dataset context(optional)
        prompt_messages, stop = self.originze_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=context,
            memory=memory
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages
        )

        if hosting_moderation_result:
            return

        # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
        self.recale_llm_max_tokens(
            model_config=app_orchestration_config.model_config,
            prompt_messages=prompt_messages
        )

        # Invoke model
        model_instance = ModelInstance(
            provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
            model=app_orchestration_config.model_config.model
        )

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=app_orchestration_config.model_config.parameters,
            stop=stop,
            stream=application_generate_entity.stream,
            user=application_generate_entity.user_id,
        )

        # handle invoke result
        self._handle_invoke_result(
            invoke_result=invoke_result,
            queue_manager=queue_manager,
            stream=application_generate_entity.stream
        )

    def moderation_for_inputs(self, app_id: str,
                              tenant_id: str,
                              app_orchestration_config_entity: AppOrchestrationConfigEntity,
                              inputs: dict,
                              query: str) -> Tuple[bool, dict, str]:
        """
        Process sensitive_word_avoidance.
        :param app_id: app id
        :param tenant_id: tenant id
        :param app_orchestration_config_entity: app orchestration config entity
        :param inputs: inputs
        :param query: query
        :return:
        """
        moderation_feature = ModerationFeature()
        return moderation_feature.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_orchestration_config_entity=app_orchestration_config_entity,
            inputs=inputs,
            query=query,
        )

    def query_app_annotations_to_reply(self, app_record: App,
                                       message: Message,
                                       query: str,
                                       user_id: str,
                                       invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply
        :param app_record: app record
        :param message: message
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :return:
        """
        annotation_reply_feature = AnnotationReplyFeature()
        return annotation_reply_feature.query(
            app_record=app_record,
            message=message,
            query=query,
            user_id=user_id,
            invoke_from=invoke_from
        )

    def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
                                                app_id: str,
                                                external_data_tools: list[ExternalDataVariableEntity],
                                                inputs: dict,
                                                query: str) -> dict:
        """
        Fill in variable inputs from external data tools if exists.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        external_data_fetch_feature = ExternalDataFetchFeature()
        return external_data_fetch_feature.fetch(
            tenant_id=tenant_id,
            app_id=app_id,
            external_data_tools=external_data_tools,
            inputs=inputs,
            query=query
        )

    def retrieve_dataset_context(self, tenant_id: str,
                                 app_record: App,
                                 queue_manager: ApplicationQueueManager,
                                 model_config: ModelConfigEntity,
                                 dataset_config: DatasetEntity,
                                 show_retrieve_source: bool,
                                 message: Message,
                                 inputs: dict,
                                 query: str,
                                 user_id: str,
                                 invoke_from: InvokeFrom,
                                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
        """
        Retrieve dataset context
        :param tenant_id: tenant id
        :param app_record: app record
        :param queue_manager: queue manager
        :param model_config: model config
        :param dataset_config: dataset config
        :param show_retrieve_source: show retrieve source
        :param message: message
        :param inputs: inputs
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :param memory: memory
        :return:
        """
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager,
            app_record.id,
            message.id,
            user_id,
            invoke_from
        )

        if (app_record.mode == AppMode.COMPLETION.value and dataset_config
                and dataset_config.retrieve_config.query_variable):
            query = inputs.get(dataset_config.retrieve_config.query_variable, "")

        dataset_retrieval = DatasetRetrievalFeature()
        return dataset_retrieval.retrieve(
            tenant_id=tenant_id,
            model_config=model_config,
            config=dataset_config,
            query=query,
            invoke_from=invoke_from,
            show_retrieve_source=show_retrieve_source,
            hit_callback=hit_callback,
            memory=memory
        )

    def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
                                 queue_manager: ApplicationQueueManager,
                                 prompt_messages: list[PromptMessage]) -> bool:
        """
        Check hosting moderation
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param prompt_messages: prompt messages
        :return:
        """
        hosting_moderation_feature = HostingModerationFeature()
        moderation_result = hosting_moderation_feature.check(
            application_generate_entity=application_generate_entity,
            prompt_messages=prompt_messages
        )

        if moderation_result:
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=application_generate_entity.app_orchestration_config_entity,
                prompt_messages=prompt_messages,
                text="I apologize for any confusion, " \
                     "but I'm an AI assistant to be helpful, harmless, and honest.",
                stream=application_generate_entity.stream
            )

        return moderation_result
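One detail of `retrieve_dataset_context` worth calling out: for completion-mode apps the retrieval query is not the user query but the configured query variable pulled from `inputs`. A small sketch with assumed values (the variable names and inputs below are illustrative, not taken from the diff):

# Assumed app configuration: the dataset retrieve config names "topic" as its
# query variable for a completion-mode app.
inputs = {"topic": "vector databases", "tone": "formal"}
query_variable = "topic"                   # dataset_config.retrieve_config.query_variable
query = inputs.get(query_variable, "")     # -> "vector databases" is what gets sent to retrieval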
api/core/app_runner/generate_task_pipeline.py (new file, 483 lines)
@@ -0,0 +1,483 @@
import json
import logging
import time
from typing import Union, Generator, cast, Optional

from pydantic import BaseModel

from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
from core.entities.application_entities import ApplicationGenerateEntity
from core.application_queue_manager import ApplicationQueueManager
from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
    QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
    AnnotationReplyEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, \
    TextPromptMessageContent, PromptMessageContentType, ImagePromptMessageContent, PromptMessage
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Message, Conversation, MessageAgentThought
from services.annotation_service import AppAnnotationService

logger = logging.getLogger(__name__)


class TaskState(BaseModel):
    """
    TaskState entity
    """
    llm_result: LLMResult
    metadata: dict = {}


class GenerateTaskPipeline:
    """
    GenerateTaskPipeline is a class that generate stream output and state management for Application.
    """

    def __init__(self, application_generate_entity: ApplicationGenerateEntity,
                 queue_manager: ApplicationQueueManager,
                 conversation: Conversation,
                 message: Message) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        """
        self._application_generate_entity = application_generate_entity
        self._queue_manager = queue_manager
        self._conversation = conversation
        self._message = message
        self._task_state = TaskState(
            llm_result=LLMResult(
                model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
                prompt_messages=[],
                message=AssistantPromptMessage(content=""),
                usage=LLMUsage.empty_usage()
            )
        )
        self._start_at = time.perf_counter()
        self._output_moderation_handler = self._init_output_moderation()

    def process(self, stream: bool) -> Union[dict, Generator]:
        """
        Process generate task pipeline.
        :return:
        """
        if stream:
            return self._process_stream_response()
        else:
            return self._process_blocking_response()

    def _process_blocking_response(self) -> dict:
        """
        Process blocking response.
        :return:
        """
        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                # Save message
                self._save_message(event.llm_result)

                response = {
                    'event': 'message',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'mode': self._conversation.mode,
                    'answer': event.llm_result.message.content,
                    'metadata': {},
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._task_state.metadata

                return response
            else:
                continue

    def _process_stream_response(self) -> Generator:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
            event = message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                    self._output_moderation_handler = None

                    replace_response = {
                        'event': 'message_replace',
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'answer': self._task_state.llm_result.message.content,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        replace_response['conversation_id'] = self._conversation.id

                    yield self._yield_response(replace_response)

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message_end',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._task_state.metadata

                yield self._yield_response(response)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueAgentThoughtEvent):
                agent_thought = (
                    db.session.query(MessageAgentThought)
                    .filter(MessageAgentThought.id == event.agent_thought_id)
                    .first()
                )

                if agent_thought:
                    response = {
                        'event': 'agent_thought',
                        'id': agent_thought.id,
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'position': agent_thought.position,
                        'thought': agent_thought.thought,
                        'tool': agent_thought.tool,
                        'tool_input': agent_thought.tool_input,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageEvent):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
                    continue

                if not self._task_state.llm_result.prompt_messages:
                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages

                if self._output_moderation_handler:
                    if self._output_moderation_handler.should_direct_output():
                        # stop subscribe new token when output moderation should direct output
                        self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
                        self._queue_manager.publish_chunk_message(LLMResultChunk(
                            model=self._task_state.llm_result.model,
                            prompt_messages=self._task_state.llm_result.prompt_messages,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                            )
                        ))
                        self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION))
                        continue
                    else:
                        self._output_moderation_handler.append_new_token(delta_text)

                self._task_state.llm_result.message.content += delta_text
                response = self._handle_chunk(delta_text)
                yield self._yield_response(response)
            elif isinstance(event, QueueMessageReplaceEvent):
                response = {
                    'event': 'message_replace',
                    'task_id': self._application_generate_entity.task_id,
                    'message_id': self._message.id,
                    'answer': event.text,
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                yield self._yield_response(response)
            elif isinstance(event, QueuePingEvent):
                yield "event: ping\n\n"
            else:
                continue

    def _save_message(self, llm_result: LLMResult) -> None:
        """
        Save message.
        :param llm_result: llm result
        :return:
        """
        usage = llm_result.usage

        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()

        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
        self._message.message_tokens = usage.prompt_tokens
        self._message.message_unit_price = usage.prompt_unit_price
        self._message.message_price_unit = usage.prompt_price_unit
        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
            if llm_result.message.content else ''
        self._message.answer_tokens = usage.completion_tokens
        self._message.answer_unit_price = usage.completion_unit_price
        self._message.answer_price_unit = usage.completion_price_unit
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.total_price = usage.total_price

        db.session.commit()

        message_was_created.send(
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras
        )

    def _handle_chunk(self, text: str) -> dict:
        """
        Handle completed event.
        :param text: text
        :return:
        """
        response = {
            'event': 'message',
            'id': self._message.id,
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            'answer': text,
            'created_at': int(self._message.created_at.timestamp())
        }

        if self._conversation.mode == 'chat':
            response['conversation_id'] = self._conversation.id

        return response

    def _handle_error(self, event: QueueErrorEvent) -> Exception:
        """
        Handle error event.
        :param event: event
        :return:
        """
        logger.debug("error: %s", event.error)
        e = event.error

        if isinstance(e, InvokeAuthorizationError):
            return InvokeAuthorizationError('Incorrect API key provided')
        elif isinstance(e, InvokeError) or isinstance(e, ValueError):
            return e
        else:
            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))

    def _yield_response(self, response: dict) -> str:
        """
        Yield response.
        :param response: response
        :return:
        """
        return "data: " + json.dumps(response) + "\n\n"

    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
        """
        Prompt messages to prompt for saving.
        :param prompt_messages: prompt messages
        :return:
        """
        prompts = []
        if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
            for prompt_message in prompt_messages:
                if prompt_message.role == PromptMessageRole.USER:
                    role = 'user'
                elif prompt_message.role == PromptMessageRole.ASSISTANT:
                    role = 'assistant'
                elif prompt_message.role == PromptMessageRole.SYSTEM:
                    role = 'system'
                else:
                    continue

                text = ''
                files = []
                if isinstance(prompt_message.content, list):
                    for content in prompt_message.content:
                        if content.type == PromptMessageContentType.TEXT:
                            content = cast(TextPromptMessageContent, content)
                            text += content.data
                        else:
                            content = cast(ImagePromptMessageContent, content)
                            files.append({
                                "type": 'image',
                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                                "detail": content.detail.value
                            })
                else:
                    text = prompt_message.content

                prompts.append({
                    "role": role,
                    "text": text,
                    "files": files
                })
        else:
            prompts.append({
                "role": 'user',
                "text": prompt_messages[0].content
            })

        return prompts

    def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
        """
        Init output moderation.
        :return:
        """
        app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
        sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance

        if sensitive_word_avoidance:
            return OutputModerationHandler(
                tenant_id=self._application_generate_entity.tenant_id,
                app_id=self._application_generate_entity.app_id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance.type,
                    config=sensitive_word_avoidance.config
                ),
                on_message_replace_func=self._queue_manager.publish_message_replace
            )
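`_yield_response` frames every streamed event as a Server-Sent-Events data line ("data: " + JSON + a blank line), with bare "event: ping" lines as keep-alives. A rough client-side sketch of consuming that stream follows; the `stream_lines` iterable is a placeholder for however the HTTP response body is read line by line, so this is an assumption about the transport, not part of the diff.

import json

def consume_sse(stream_lines):
    """Parse the 'data: {...}' frames emitted by GenerateTaskPipeline."""
    for line in stream_lines:
        line = line.strip()
        if not line or line == "event: ping":          # keep-alive, nothing to do
            continue
        if line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            if event["event"] == "message":
                print(event["answer"], end="")          # incremental answer chunk
            elif event["event"] == "message_replace":
                print("\n[replaced] " + event["answer"])  # moderated replacement text
            elif event["event"] == "message_end":
                break                                   # stream is complete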
api/core/app_runner/moderation_handler.py (new file, 138 lines)
@@ -0,0 +1,138 @@
import logging
import threading
import time
from typing import Any, Optional, Dict

from flask import current_app, Flask
from pydantic import BaseModel

from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory

logger = logging.getLogger(__name__)


class ModerationRule(BaseModel):
    type: str
    config: Dict[str, Any]


class OutputModerationHandler(BaseModel):
    DEFAULT_BUFFER_SIZE: int = 300

    tenant_id: str
    app_id: str

    rule: ModerationRule
    on_message_replace_func: Any

    thread: Optional[threading.Thread] = None
    thread_running: bool = True
    buffer: str = ''
    is_final_chunk: bool = False
    final_output: Optional[str] = None

    class Config:
        arbitrary_types_allowed = True

    def should_direct_output(self):
        return self.final_output is not None

    def get_final_output(self):
        return self.final_output

    def append_new_token(self, token: str):
        self.buffer += token

        if not self.thread:
            self.thread = self.start_thread()

    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
        self.buffer = completion
        self.is_final_chunk = True

        result = self.moderation(
            tenant_id=self.tenant_id,
            app_id=self.app_id,
            moderation_buffer=completion
        )

        if not result or not result.flagged:
            return completion

        if result.action == ModerationAction.DIRECT_OUTPUT:
            final_output = result.preset_response
        else:
            final_output = result.text

        if public_event:
            self.on_message_replace_func(final_output)

        return final_output

    def start_thread(self) -> threading.Thread:
        buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
        thread = threading.Thread(target=self.worker, kwargs={
            'flask_app': current_app._get_current_object(),
            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
        })

        thread.start()

        return thread

    def stop_thread(self):
        if self.thread and self.thread.is_alive():
            self.thread_running = False

    def worker(self, flask_app: Flask, buffer_size: int):
        with flask_app.app_context():
            current_length = 0
            while self.thread_running:
                moderation_buffer = self.buffer
                buffer_length = len(moderation_buffer)
                if not self.is_final_chunk:
                    chunk_length = buffer_length - current_length
                    if 0 <= chunk_length < buffer_size:
                        time.sleep(1)
                        continue

                current_length = buffer_length

                result = self.moderation(
                    tenant_id=self.tenant_id,
                    app_id=self.app_id,
                    moderation_buffer=moderation_buffer
                )

                if not result or not result.flagged:
                    continue

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    final_output = result.preset_response
                    self.final_output = final_output
                else:
                    final_output = result.text + self.buffer[len(moderation_buffer):]

                # trigger replace event
                if self.thread_running:
                    self.on_message_replace_func(final_output)

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    break

    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
        try:
            moderation_factory = ModerationFactory(
                name=self.rule.type,
                app_id=app_id,
                tenant_id=tenant_id,
                config=self.rule.config
            )

            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
            return result
        except Exception as e:
            logger.error("Moderation Output error: %s", e)

        return None
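To make the buffering behaviour concrete, here is a minimal usage sketch of `OutputModerationHandler`, mirroring how `GenerateTaskPipeline` drives it. The rule type, config and callback are assumptions for illustration, and the sketch assumes it runs inside a Flask application context, since `start_thread` reads `current_app.config`.

# Hypothetical wiring; 'keywords' is an example moderation rule type, not taken from the diff.
handler = OutputModerationHandler(
    tenant_id="tenant-id",
    app_id="app-id",
    rule=ModerationRule(type="keywords", config={"keywords": ["forbidden"]}),
    on_message_replace_func=lambda text: print("replace with:", text)
)

for token in ["Hello", ", ", "world"]:
    handler.append_new_token(token)      # lazily starts the background worker (needs Flask app context)

handler.stop_thread()                    # signal the worker loop to exit
final = handler.moderation_completion(completion="Hello, world", public_event=False)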
api/core/application_manager.py (new file, 655 lines)
@@ -0,0 +1,655 @@
import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from typing import cast, Optional, Any, Union, Generator, Tuple
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app_runner.agent_app_runner import AgentApplicationRunner
|
||||
from core.app_runner.basic_app_runner import BasicApplicationRunner
|
||||
from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
|
||||
from core.entities.application_entities import ApplicationGenerateEntity, AppOrchestrationConfigEntity, \
|
||||
ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \
|
||||
AdvancedCompletionPromptTemplateEntity, ExternalDataVariableEntity, DatasetEntity, DatasetRetrieveConfigEntity, \
|
||||
AgentEntity, AgentToolEntity, FileUploadEntity, SensitiveWordAvoidanceEntity, InvokeFrom
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.file.file_obj import FileObj
|
||||
from core.errors.error import QuotaExceededError, ProviderTokenNotInitError, ModelCurrentlyNotSupportError
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser, Conversation, Message, MessageFile, App
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApplicationManager:
|
||||
"""
|
||||
This class is responsible for managing application
|
||||
"""
|
||||
|
||||
def generate(self, tenant_id: str,
|
||||
app_id: str,
|
||||
app_model_config_id: str,
|
||||
app_model_config_dict: dict,
|
||||
app_model_config_override: bool,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
inputs: dict[str, str],
|
||||
query: Optional[str] = None,
|
||||
files: Optional[list[FileObj]] = None,
|
||||
conversation: Optional[Conversation] = None,
|
||||
stream: bool = False,
|
||||
extras: Optional[dict[str, Any]] = None) \
|
||||
-> Union[dict, Generator]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param tenant_id: workspace ID
|
||||
:param app_id: app ID
|
||||
:param app_model_config_id: app model config id
|
||||
:param app_model_config_dict: app model config dict
|
||||
:param app_model_config_override: app model config override
|
||||
:param user: account or end user
|
||||
:param invoke_from: invoke from source
|
||||
:param inputs: inputs
|
||||
:param query: query
|
||||
:param files: file obj list
|
||||
:param conversation: conversation
|
||||
:param stream: is stream
|
||||
:param extras: extras
|
||||
"""
|
||||
# init task id
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = ApplicationGenerateEntity(
|
||||
task_id=task_id,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
app_model_config_id=app_model_config_id,
|
||||
app_model_config_dict=app_model_config_dict,
|
||||
app_orchestration_config_entity=self._convert_from_app_model_config_dict(
|
||||
tenant_id=tenant_id,
|
||||
app_model_config_dict=app_model_config_dict
|
||||
),
|
||||
app_model_config_override=app_model_config_override,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs if conversation else inputs,
|
||||
query=query.replace('\x00', '') if query else None,
|
||||
files=files if files else [],
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras
|
||||
)
|
||||
|
||||
# init generate records
|
||||
(
|
||||
conversation,
|
||||
message
|
||||
) = self._init_generate_records(application_generate_entity)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = ApplicationQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'conversation_id': conversation.id,
|
||||
'message_id': message.id,
|
||||
})
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
# return response or stream generator
|
||||
return self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: ApplicationGenerateEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param conversation_id: conversation ID
|
||||
:param message_id: message ID
|
||||
:return:
|
||||
"""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
|
||||
if application_generate_entity.app_orchestration_config_entity.agent:
|
||||
# agent app
|
||||
runner = AgentApplicationRunner()
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
else:
|
||||
# basic app
|
||||
runner = BasicApplicationRunner()
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e)
|
||||
except (ValueError, InvokeError) as e:
|
||||
queue_manager.publish_error(e)
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating")
|
||||
queue_manager.publish_error(e)
|
||||
finally:
|
||||
db.session.remove()
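# --- Illustrative note, not part of the diff ------------------------------------------
# The worker above is handed current_app._get_current_object() rather than the
# current_app proxy: the proxy only resolves inside an app/request context, which a
# freshly spawned thread does not have, so the concrete app object is passed in and a
# new context is pushed with flask_app.app_context(). Minimal reproduction:

import threading
from flask import Flask, current_app

_example_app = Flask(__name__)


def _example_background_job(flask_app: Flask) -> None:
    with flask_app.app_context():   # push a context so current_app / db sessions work here
        print(current_app.name)


with _example_app.app_context():
    concrete_app = current_app._get_current_object()   # the real app object, not the proxy
    threading.Thread(target=_example_background_job, args=(concrete_app,)).start()
# ---------------------------------------------------------------------------------------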
|
||||
|
||||
def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
stream: bool = False) -> Union[dict, Generator]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
:param stream: is stream
|
||||
:return:
|
||||
"""
|
||||
# init generate task pipeline
|
||||
generate_task_pipeline = GenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
|
||||
try:
|
||||
return generate_task_pipeline.process(stream=stream)
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.":  # the client closed the output stream; treat it as a stopped task
|
||||
raise ConversationTaskStoppedException()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
finally:
|
||||
db.session.remove()
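# --- Illustrative usage sketch, not part of the diff (the helper name is made up) ------
# The value returned by _handle_response() is a plain dict for blocking calls and a
# generator of chunks when stream=True, so callers branch on the type:

from typing import Generator, Union


def _example_print_result(result: Union[dict, Generator]) -> None:
    if isinstance(result, dict):
        print(result)               # stream=False: one blocking response
    else:
        for chunk in result:        # stream=True: server-sent chunks
            print(chunk)
# ---------------------------------------------------------------------------------------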
|
||||
|
||||
def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
|
||||
-> AppOrchestrationConfigEntity:
|
||||
"""
|
||||
Convert app model config dict to entity.
|
||||
:param tenant_id: tenant ID
|
||||
:param app_model_config_dict: app model config dict
|
||||
:raises ProviderTokenNotInitError: provider token not init error
|
||||
:return: app orchestration config entity
|
||||
"""
|
||||
properties = {}
|
||||
|
||||
copy_app_model_config_dict = app_model_config_dict.copy()
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=tenant_id,
|
||||
provider=copy_app_model_config_dict['model']['provider'],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_name = copy_app_model_config_dict['model']['name']
|
||||
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
# check model credentials
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model=copy_app_model_config_dict['model']['name']
|
||||
)
|
||||
|
||||
if model_credentials is None:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.")
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=copy_app_model_config_dict['model']['name'],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
model_name = copy_app_model_config_dict['model']['name']
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.")
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||
|
||||
# model config
|
||||
completion_params = copy_app_model_config_dict['model'].get('completion_params')
|
||||
stop = []
|
||||
if 'stop' in completion_params:
|
||||
stop = completion_params['stop']
|
||||
del completion_params['stop']
|
||||
|
||||
# get model mode
|
||||
model_mode = copy_app_model_config_dict['model'].get('mode')
|
||||
if not model_mode:
|
||||
mode_enum = model_type_instance.get_model_mode(
|
||||
model=copy_app_model_config_dict['model']['name'],
|
||||
credentials=model_credentials
|
||||
)
|
||||
|
||||
model_mode = mode_enum.value
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
copy_app_model_config_dict['model']['name'],
|
||||
model_credentials
|
||||
)
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
properties['model_config'] = ModelConfigEntity(
|
||||
provider=copy_app_model_config_dict['model']['provider'],
|
||||
model=copy_app_model_config_dict['model']['name'],
|
||||
model_schema=model_schema,
|
||||
mode=model_mode,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=model_credentials,
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
# prompt template
|
||||
prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
|
||||
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
|
||||
properties['prompt_template'] = PromptTemplateEntity(
|
||||
prompt_type=prompt_type,
|
||||
simple_prompt_template=simple_prompt_template
|
||||
)
|
||||
else:
|
||||
advanced_chat_prompt_template = None
|
||||
chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
|
||||
if chat_prompt_config:
|
||||
chat_prompt_messages = []
|
||||
for message in chat_prompt_config.get("prompt", []):
|
||||
chat_prompt_messages.append({
|
||||
"text": message["text"],
|
||||
"role": PromptMessageRole.value_of(message["role"])
|
||||
})
|
||||
|
||||
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
|
||||
messages=chat_prompt_messages
|
||||
)
|
||||
|
||||
advanced_completion_prompt_template = None
|
||||
completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
|
||||
if completion_prompt_config:
|
||||
completion_prompt_template_params = {
|
||||
'prompt': completion_prompt_config['prompt']['text'],
|
||||
}
|
||||
|
||||
if 'conversation_histories_role' in completion_prompt_config:
|
||||
completion_prompt_template_params['role_prefix'] = {
|
||||
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
|
||||
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
|
||||
}
|
||||
|
||||
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
|
||||
**completion_prompt_template_params
|
||||
)
|
||||
|
||||
properties['prompt_template'] = PromptTemplateEntity(
|
||||
prompt_type=prompt_type,
|
||||
advanced_chat_prompt_template=advanced_chat_prompt_template,
|
||||
advanced_completion_prompt_template=advanced_completion_prompt_template
|
||||
)
|
||||
|
||||
# external data variables
|
||||
properties['external_data_variables'] = []
|
||||
external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
|
||||
for external_data_tool in external_data_tools:
|
||||
if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
|
||||
continue
|
||||
|
||||
properties['external_data_variables'].append(
|
||||
ExternalDataVariableEntity(
|
||||
variable=external_data_tool['variable'],
|
||||
type=external_data_tool['type'],
|
||||
config=external_data_tool['config']
|
||||
)
|
||||
)
|
||||
|
||||
# show retrieve source
|
||||
show_retrieve_source = False
|
||||
retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
|
||||
if retriever_resource_dict:
|
||||
if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
|
||||
show_retrieve_source = True
|
||||
|
||||
properties['show_retrieve_source'] = show_retrieve_source
|
||||
|
||||
agent_dict = copy_app_model_config_dict.get('agent_mode')
if agent_dict and agent_dict.get('enabled'):
|
||||
if agent_dict['strategy'] in ['router', 'react_router']:
|
||||
dataset_ids = []
|
||||
for tool in agent_dict.get('tools', []):
|
||||
key = list(tool.keys())[0]
|
||||
|
||||
if key != 'dataset':
|
||||
continue
|
||||
|
||||
tool_item = tool[key]
|
||||
|
||||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||
continue
|
||||
|
||||
dataset_id = tool_item['id']
|
||||
dataset_ids.append(dataset_id)
|
||||
|
||||
dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
|
||||
query_variable = copy_app_model_config_dict.get('dataset_query_variable')
|
||||
if dataset_configs['retrieval_model'] == 'single':
|
||||
properties['dataset'] = DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable=query_variable,
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs['retrieval_model']
|
||||
),
|
||||
single_strategy=agent_dict['strategy']
|
||||
)
|
||||
)
|
||||
else:
|
||||
properties['dataset'] = DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable=query_variable,
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs['retrieval_model']
|
||||
),
|
||||
top_k=dataset_configs.get('top_k'),
|
||||
score_threshold=dataset_configs.get('score_threshold'),
|
||||
reranking_model=dataset_configs.get('reranking_model')
|
||||
)
|
||||
)
|
||||
else:
|
||||
if agent_dict['strategy'] == 'react':
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
else:
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
|
||||
agent_tools = []
|
||||
for tool in agent_dict.get('tools', []):
|
||||
key = list(tool.keys())[0]
|
||||
tool_item = tool[key]
|
||||
|
||||
agent_tool_properties = {
|
||||
"tool_id": key
|
||||
}
|
||||
|
||||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||
continue
|
||||
|
||||
agent_tool_properties["config"] = tool_item
|
||||
agent_tools.append(AgentToolEntity(**agent_tool_properties))
|
||||
|
||||
properties['agent'] = AgentEntity(
|
||||
provider=properties['model_config'].provider,
|
||||
model=properties['model_config'].model,
|
||||
strategy=strategy,
|
||||
tools=agent_tools
|
||||
)
|
||||
|
||||
# file upload
|
||||
file_upload_dict = copy_app_model_config_dict.get('file_upload')
|
||||
if file_upload_dict:
|
||||
if 'image' in file_upload_dict and file_upload_dict['image']:
|
||||
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
|
||||
properties['file_upload'] = FileUploadEntity(
|
||||
image_config={
|
||||
'number_limits': file_upload_dict['image']['number_limits'],
|
||||
'detail': file_upload_dict['image']['detail'],
|
||||
'transfer_methods': file_upload_dict['image']['transfer_methods']
|
||||
}
|
||||
)
|
||||
|
||||
# opening statement
|
||||
properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')
|
||||
|
||||
# suggested questions after answer
|
||||
suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
|
||||
if suggested_questions_after_answer_dict:
|
||||
if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
|
||||
properties['suggested_questions_after_answer'] = True
|
||||
|
||||
# more like this
|
||||
more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
|
||||
if more_like_this_dict:
|
||||
if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
|
||||
properties['more_like_this'] = True
|
||||
|
||||
# speech to text
|
||||
speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
|
||||
if speech_to_text_dict:
|
||||
if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
|
||||
properties['speech_to_text'] = True
|
||||
|
||||
# sensitive word avoidance
|
||||
sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
|
||||
if sensitive_word_avoidance_dict:
|
||||
if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
|
||||
properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
|
||||
type=sensitive_word_avoidance_dict.get('type'),
|
||||
config=sensitive_word_avoidance_dict.get('config'),
|
||||
)
|
||||
|
||||
return AppOrchestrationConfigEntity(**properties)
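# --- Illustrative example, not part of the diff ----------------------------------------
# A minimal app_model_config_dict for the simple-prompt path of
# _convert_from_app_model_config_dict() above. Only keys that the converter actually
# reads are shown, and every value is made up.

_example_app_model_config_dict = {
    "model": {
        "provider": "openai",                     # any configured provider
        "name": "gpt-3.5-turbo",                  # any model of that provider
        "mode": "chat",                           # optional; inferred when missing
        "completion_params": {"temperature": 0.7, "stop": ["Human:"]},
    },
    "prompt_type": "simple",
    "pre_prompt": "You are a helpful assistant.",
    "external_data_tools": [],
    "retriever_resource": {"enabled": False},
    "opening_statement": "Hi there, how can I help you today?",
    "suggested_questions_after_answer": {"enabled": True},
    "more_like_this": {"enabled": False},
    "speech_to_text": {"enabled": False},
    "sensitive_word_avoidance": {"enabled": False},
}
# ---------------------------------------------------------------------------------------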
|
||||
|
||||
def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
|
||||
-> Tuple[Conversation, Message]:
|
||||
"""
|
||||
Initialize generate records
|
||||
:param application_generate_entity: application generate entity
|
||||
:return:
|
||||
"""
|
||||
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
|
||||
|
||||
model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=app_orchestration_config_entity.model_config.model,
|
||||
credentials=app_orchestration_config_entity.model_config.credentials
|
||||
)
|
||||
|
||||
app_record = (db.session.query(App)
|
||||
.filter(App.id == application_generate_entity.app_id).first())
|
||||
|
||||
app_mode = app_record.mode
|
||||
|
||||
# get from source
|
||||
end_user_id = None
|
||||
account_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
from_source = 'api'
|
||||
end_user_id = application_generate_entity.user_id
|
||||
else:
|
||||
from_source = 'console'
|
||||
account_id = application_generate_entity.user_id
|
||||
|
||||
override_model_configs = None
|
||||
if application_generate_entity.app_model_config_override:
|
||||
override_model_configs = application_generate_entity.app_model_config_dict
|
||||
|
||||
introduction = ''
|
||||
if app_mode == 'chat':
|
||||
# get conversation introduction
|
||||
introduction = self._get_conversation_introduction(application_generate_entity)
|
||||
|
||||
if not application_generate_entity.conversation_id:
|
||||
conversation = Conversation(
|
||||
app_id=app_record.id,
|
||||
app_model_config_id=application_generate_entity.app_model_config_id,
|
||||
model_provider=app_orchestration_config_entity.model_config.provider,
|
||||
model_id=app_orchestration_config_entity.model_config.model,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
mode=app_mode,
|
||||
name='New conversation',
|
||||
inputs=application_generate_entity.inputs,
|
||||
introduction=introduction,
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status='normal',
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
else:
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(
|
||||
Conversation.id == application_generate_entity.conversation_id,
|
||||
Conversation.app_id == app_record.id
|
||||
).first()
|
||||
)
|
||||
|
||||
currency = model_schema.pricing.currency if model_schema.pricing else 'USD'
|
||||
|
||||
message = Message(
|
||||
app_id=app_record.id,
|
||||
model_provider=app_orchestration_config_entity.model_config.provider,
|
||||
model_id=app_orchestration_config_entity.model_config.model,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
conversation_id=conversation.id,
|
||||
inputs=application_generate_entity.inputs,
|
||||
query=application_generate_entity.query or "",
|
||||
message="",
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
answer="",
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency=currency,
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id,
|
||||
agent_based=app_orchestration_config_entity.agent is not None
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
|
||||
for file in application_generate_entity.files:
|
||||
message_file = MessageFile(
|
||||
message_id=message.id,
|
||||
type=file.type.value,
|
||||
transfer_method=file.transfer_method.value,
|
||||
url=file.url,
|
||||
upload_file_id=file.upload_file_id,
|
||||
created_by_role=('account' if account_id else 'end_user'),
|
||||
created_by=account_id or end_user_id,
|
||||
)
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
|
||||
return conversation, message
|
||||
|
||||
def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
|
||||
"""
|
||||
Get conversation introduction
|
||||
:param application_generate_entity: application generate entity
|
||||
:return: conversation introduction
|
||||
"""
|
||||
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
|
||||
introduction = app_orchestration_config_entity.opening_statement
|
||||
|
||||
if introduction:
|
||||
try:
|
||||
inputs = application_generate_entity.inputs
|
||||
prompt_template = PromptTemplateParser(template=introduction)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
introduction = prompt_template.format(prompt_inputs)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return introduction
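# --- Illustrative sketch, not part of the diff ------------------------------------------
# Mirrors the PromptTemplateParser calls used above. The {{variable}} template syntax is
# an assumption, not something stated in this diff, and the import of
# PromptTemplateParser is the same one already used by this module (not shown here).

_example_template = "Hello {{name}}, ask me anything about {{topic}}."
_example_inputs = {"name": "Alice", "topic": "billing", "unused": "ignored"}

_example_parser = PromptTemplateParser(template=_example_template)
_example_prompt_inputs = {k: _example_inputs[k]
                          for k in _example_parser.variable_keys if k in _example_inputs}
_example_introduction = _example_parser.format(_example_prompt_inputs)
# expected, under the assumed syntax: "Hello Alice, ask me anything about billing."
# -----------------------------------------------------------------------------------------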
|
||||
|
||||
def _get_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""
|
||||
Get conversation by conversation id
|
||||
:param conversation_id: conversation id
|
||||
:return: conversation
|
||||
"""
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
return conversation
|
||||
|
||||
def _get_message(self, message_id: str) -> Message:
|
||||
"""
|
||||
Get message by message id
|
||||
:param message_id: message id
|
||||
:return: message
|
||||
"""
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
.filter(Message.id == message_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
return message
|
||||
api/core/application_queue_manager.py (new file, 228 lines added)
@@ -0,0 +1,228 @@
|
||||
import queue
|
||||
import time
|
||||
from typing import Generator, Any
|
||||
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.entities.queue_entities import QueueStopEvent, AppQueueEvent, QueuePingEvent, QueueErrorEvent, \
|
||||
QueueAgentThoughtEvent, QueueMessageEndEvent, QueueRetrieverResourcesEvent, QueueMessageReplaceEvent, \
|
||||
QueueMessageEvent, QueueMessage, AnnotationReplyEvent
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import MessageAgentThought
|
||||
|
||||
|
||||
class ApplicationQueueManager:
|
||||
def __init__(self, task_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
conversation_id: str,
|
||||
app_mode: str,
|
||||
message_id: str) -> None:
|
||||
if not user_id:
|
||||
raise ValueError("user is required")
|
||||
|
||||
self._task_id = task_id
|
||||
self._user_id = user_id
|
||||
self._invoke_from = invoke_from
|
||||
self._conversation_id = str(conversation_id)
|
||||
self._app_mode = app_mode
|
||||
self._message_id = str(message_id)
|
||||
|
||||
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
||||
redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
|
||||
|
||||
q = queue.Queue()
|
||||
|
||||
self._q = q
|
||||
|
||||
def listen(self) -> Generator:
|
||||
"""
|
||||
Listen to queue
|
||||
:return:
|
||||
"""
|
||||
# stop listening after at most 10 minutes
|
||||
listen_timeout = 600
|
||||
start_time = time.time()
|
||||
last_ping_time = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = self._q.get(timeout=1)
|
||||
if message is None:
|
||||
break
|
||||
|
||||
yield message
|
||||
except queue.Empty:
|
||||
continue
|
||||
finally:
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time >= listen_timeout or self._is_stopped():
|
||||
# publish two messages to make sure the client can receive the stop signal
|
||||
# and stop listening after the stop signal has been processed
|
||||
self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
|
||||
self.stop_listen()
|
||||
|
||||
if elapsed_time // 10 > last_ping_time:
|
||||
self.publish(QueuePingEvent())
|
||||
last_ping_time = elapsed_time // 10
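# --- Illustrative consumer sketch, not part of the diff ----------------------------------
# listen() yields QueueMessage objects (interleaved with QueuePingEvent) until
# stop_listen(), a QueueStopEvent, or the 10-minute timeout closes the queue.
def _example_consume_events(queue_manager: "ApplicationQueueManager") -> None:
    for queue_message in queue_manager.listen():
        print(queue_message.event)   # a real caller would dispatch on the event type
# ------------------------------------------------------------------------------------------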
|
||||
|
||||
def stop_listen(self) -> None:
|
||||
"""
|
||||
Stop listening to the queue
|
||||
:return:
|
||||
"""
|
||||
self._q.put(None)
|
||||
|
||||
def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
|
||||
"""
|
||||
Publish chunk message to channel
|
||||
|
||||
:param chunk: chunk
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueMessageEvent(
|
||||
chunk=chunk
|
||||
))
|
||||
|
||||
def publish_message_replace(self, text: str) -> None:
|
||||
"""
|
||||
Publish message replace
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueMessageReplaceEvent(
|
||||
text=text
|
||||
))
|
||||
|
||||
def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
|
||||
"""
|
||||
Publish retriever resources
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
|
||||
|
||||
def publish_annotation_reply(self, message_annotation_id: str) -> None:
|
||||
"""
|
||||
Publish annotation reply
|
||||
:param message_annotation_id: message annotation id
|
||||
:return:
|
||||
"""
|
||||
self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
|
||||
|
||||
def publish_message_end(self, llm_result: LLMResult) -> None:
|
||||
"""
|
||||
Publish message end
|
||||
:param llm_result: llm result
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueMessageEndEvent(llm_result=llm_result))
|
||||
self.stop_listen()
|
||||
|
||||
def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
|
||||
"""
|
||||
Publish agent thought
|
||||
:param message_agent_thought: message agent thought
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=message_agent_thought.id
|
||||
))
|
||||
|
||||
def publish_error(self, e) -> None:
|
||||
"""
|
||||
Publish error
|
||||
:param e: error
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueErrorEvent(
|
||||
error=e
|
||||
))
|
||||
self.stop_listen()
|
||||
|
||||
def publish(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
:return:
|
||||
"""
|
||||
self._check_for_sqlalchemy_models(event.dict())
|
||||
|
||||
message = QueueMessage(
|
||||
task_id=self._task_id,
|
||||
message_id=self._message_id,
|
||||
conversation_id=self._conversation_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(event, QueueStopEvent):
|
||||
self.stop_listen()
|
||||
|
||||
@classmethod
|
||||
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
|
||||
"""
|
||||
Set task stop flag
|
||||
:return:
|
||||
"""
|
||||
result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
|
||||
if result is None:
|
||||
return
|
||||
|
||||
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
||||
if result != f"{user_prefix}-{user_id}":
|
||||
return
|
||||
|
||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
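# --- Illustrative sketch, not part of the diff --------------------------------------------
# A stop endpoint would typically call set_stop_flag() with the same task_id and a
# user_id/invoke_from pair that maps to the same account or end-user prefix that started
# the task; listen() then notices the Redis flag via _is_stopped() and publishes a
# QueueStopEvent. All argument values below are made up.
def _example_stop_task(task_id: str, user_id: str) -> None:
    ApplicationQueueManager.set_stop_flag(
        task_id=task_id,
        invoke_from=InvokeFrom.WEB_APP,   # must match the source that created the task
        user_id=user_id,
    )
# -------------------------------------------------------------------------------------------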
|
||||
|
||||
def _is_stopped(self) -> bool:
|
||||
"""
|
||||
Check if task is stopped
|
||||
:return:
|
||||
"""
|
||||
stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
|
||||
result = redis_client.get(stopped_cache_key)
|
||||
if result is not None:
|
||||
redis_client.delete(stopped_cache_key)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _generate_task_belong_cache_key(cls, task_id: str) -> str:
|
||||
"""
|
||||
Generate task belong cache key
|
||||
:param task_id: task id
|
||||
:return:
|
||||
"""
|
||||
return f"generate_task_belong:{task_id}"
|
||||
|
||||
@classmethod
|
||||
def _generate_stopped_cache_key(cls, task_id: str) -> str:
|
||||
"""
|
||||
Generate stopped cache key
|
||||
:param task_id: task id
|
||||
:return:
|
||||
"""
|
||||
return f"generate_task_stopped:{task_id}"
|
||||
|
||||
def _check_for_sqlalchemy_models(self, data: Any):
|
||||
# from entity to dict or list
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
self._check_for_sqlalchemy_models(value)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
self._check_for_sqlalchemy_models(item)
|
||||
else:
|
||||
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
|
||||
raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
|
||||
"that cause thread safety issues is not allowed.")
|
||||
|
||||
|
||||
class ConversationTaskStoppedException(Exception):
|
||||
pass
|
||||
@@ -2,30 +2,40 @@ import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
from typing import Any, Dict, List, Union, Optional, cast
|
||||
|
||||
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
|
||||
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
|
||||
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessage
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from extensions.ext_database import db
|
||||
from models.model import MessageChain, MessageAgentThought, Message
|
||||
|
||||
|
||||
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
|
||||
def __init__(self, model_config: ModelConfigEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
message: Message,
|
||||
message_chain: MessageChain) -> None:
|
||||
"""Initialize callback handler."""
|
||||
self.model_instance = model_instance
|
||||
self.conversation_message_task = conversation_message_task
|
||||
self.model_config = model_config
|
||||
self.queue_manager = queue_manager
|
||||
self.message = message
|
||||
self.message_chain = message_chain
|
||||
model_type_instance = self.model_config.provider_model_bundle.model_type_instance
|
||||
self.model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
self.current_chain = None
|
||||
|
||||
@property
|
||||
def agent_loops(self) -> List[AgentLoop]:
|
||||
@@ -46,65 +56,60 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return True
|
||||
|
||||
def on_llm_before_invoke(self, prompt_messages: list[PromptMessage]) -> None:
|
||||
if not self._current_loop:
|
||||
# Agent start with a LLM query
|
||||
self._current_loop = AgentLoop(
|
||||
position=len(self._agent_loops) + 1,
|
||||
prompt="\n".join([prompt_message.content for prompt_message in prompt_messages]),
|
||||
status='llm_started',
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
|
||||
def on_llm_after_invoke(self, result: RuntimeLLMResult) -> None:
|
||||
if self._current_loop and self._current_loop.status == 'llm_started':
|
||||
self._current_loop.status = 'llm_end'
|
||||
if result.usage:
|
||||
self._current_loop.prompt_tokens = result.usage.prompt_tokens
|
||||
else:
|
||||
self._current_loop.prompt_tokens = self.model_type_instance.get_num_tokens(
|
||||
model=self.model_config.model,
|
||||
credentials=self.model_config.credentials,
|
||||
prompt_messages=[UserPromptMessage(content=self._current_loop.prompt)]
|
||||
)
|
||||
|
||||
completion_message = result.message
|
||||
if completion_message.tool_calls:
|
||||
self._current_loop.completion \
|
||||
= json.dumps({'function_call': completion_message.tool_calls})
|
||||
else:
|
||||
self._current_loop.completion = completion_message.content
|
||||
|
||||
if result.usage:
|
||||
self._current_loop.completion_tokens = result.usage.completion_tokens
|
||||
else:
|
||||
self._current_loop.completion_tokens = self.model_type_instance.get_num_tokens(
|
||||
model=self.model_config.model,
|
||||
credentials=self.model_config.credentials,
|
||||
prompt_messages=[AssistantPromptMessage(content=self._current_loop.completion)]
|
||||
)
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
if not self._current_loop:
|
||||
# Agent start with a LLM query
|
||||
self._current_loop = AgentLoop(
|
||||
position=len(self._agent_loops) + 1,
|
||||
prompt="\n".join([message.content for message in messages[0]]),
|
||||
status='llm_started',
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
pass
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out the prompts."""
|
||||
# serialized={'name': 'OpenAI'}
|
||||
# prompts=['Answer the following questions...\nThought:']
|
||||
# kwargs={}
|
||||
if not self._current_loop:
|
||||
# Agent start with a LLM query
|
||||
self._current_loop = AgentLoop(
|
||||
position=len(self._agent_loops) + 1,
|
||||
prompt=prompts[0],
|
||||
status='llm_started',
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
# kwargs={}
|
||||
if self._current_loop and self._current_loop.status == 'llm_started':
|
||||
self._current_loop.status = 'llm_end'
|
||||
if response.llm_output:
|
||||
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
else:
|
||||
self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.prompt)]
|
||||
)
|
||||
completion_generation = response.generations[0][0]
|
||||
if isinstance(completion_generation, ChatGeneration):
|
||||
completion_message = completion_generation.message
|
||||
if 'function_call' in completion_message.additional_kwargs:
|
||||
self._current_loop.completion \
|
||||
= json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
|
||||
else:
|
||||
self._current_loop.completion = response.generations[0][0].text
|
||||
else:
|
||||
self._current_loop.completion = completion_generation.text
|
||||
|
||||
if response.llm_output:
|
||||
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.completion)]
|
||||
)
|
||||
pass
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
@@ -150,10 +155,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
if completion is not None:
|
||||
self._current_loop.completion = completion
|
||||
|
||||
self._message_agent_thought = self.conversation_message_task.on_agent_start(
|
||||
self.current_chain,
|
||||
self._current_loop
|
||||
)
|
||||
self._message_agent_thought = self._init_agent_thought()
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
@@ -176,9 +178,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._current_loop.completed_at = time.perf_counter()
|
||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_instance, self._current_loop
|
||||
)
|
||||
self._complete_agent_thought(self._message_agent_thought)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
self._current_loop = None
|
||||
@@ -202,17 +202,62 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._current_loop.completed_at = time.perf_counter()
|
||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||
self._current_loop.thought = '[DONE]'
|
||||
self._message_agent_thought = self.conversation_message_task.on_agent_start(
|
||||
self.current_chain,
|
||||
self._current_loop
|
||||
)
|
||||
self._message_agent_thought = self._init_agent_thought()
|
||||
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_instance, self._current_loop
|
||||
)
|
||||
self._complete_agent_thought(self._message_agent_thought)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
elif not self._current_loop and self._agent_loops:
|
||||
self._agent_loops[-1].status = 'agent_finish'
|
||||
|
||||
def _init_agent_thought(self) -> MessageAgentThought:
|
||||
message_agent_thought = MessageAgentThought(
|
||||
message_id=self.message.id,
|
||||
message_chain_id=self.message_chain.id,
|
||||
position=self._current_loop.position,
|
||||
thought=self._current_loop.thought,
|
||||
tool=self._current_loop.tool_name,
|
||||
tool_input=self._current_loop.tool_input,
|
||||
message=self._current_loop.prompt,
|
||||
message_price_unit=0,
|
||||
answer=self._current_loop.completion,
|
||||
answer_price_unit=0,
|
||||
created_by_role=('account' if self.message.from_source == 'console' else 'end_user'),
|
||||
created_by=(self.message.from_account_id
|
||||
if self.message.from_source == 'console' else self.message.from_end_user_id)
|
||||
)
|
||||
|
||||
db.session.add(message_agent_thought)
|
||||
db.session.commit()
|
||||
|
||||
self.queue_manager.publish_agent_thought(message_agent_thought)
|
||||
|
||||
return message_agent_thought
|
||||
|
||||
def _complete_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
|
||||
loop_message_tokens = self._current_loop.prompt_tokens
|
||||
loop_answer_tokens = self._current_loop.completion_tokens
|
||||
|
||||
# transform usage
|
||||
llm_usage = self.model_type_instance._calc_response_usage(
|
||||
self.model_config.model,
|
||||
self.model_config.credentials,
|
||||
loop_message_tokens,
|
||||
loop_answer_tokens
|
||||
)
|
||||
|
||||
message_agent_thought.observation = self._current_loop.tool_output
|
||||
message_agent_thought.tool_process_data = ''  # currently not supported
|
||||
message_agent_thought.message_token = loop_message_tokens
|
||||
message_agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||
message_agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
message_agent_thought.answer_token = loop_answer_tokens
|
||||
message_agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||
message_agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||
message_agent_thought.latency = self._current_loop.latency
|
||||
message_agent_thought.tokens = self._current_loop.prompt_tokens + self._current_loop.completion_tokens
|
||||
message_agent_thought.total_price = llm_usage.total_price
|
||||
message_agent_thought.currency = llm_usage.currency
|
||||
db.session.commit()
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
from core.callback_handler.entity.dataset_query import DatasetQueryObj
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
|
||||
|
||||
class DatasetToolCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
|
||||
"""Initialize callback handler."""
|
||||
self.queries = []
|
||||
self.conversation_message_task = conversation_message_task
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return False
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
tool_name: str = serialized.get('name')
|
||||
dataset_id = tool_name.removeprefix('dataset-')
|
||||
|
||||
try:
|
||||
input_dict = json.loads(input_str.replace("'", "\""))
|
||||
query = input_dict.get('query')
|
||||
except JSONDecodeError:
|
||||
query = input_str
|
||||
|
||||
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
logging.debug("Dataset tool on_llm_error: %s", error)
|
||||
@@ -1,16 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ChainResult(BaseModel):
|
||||
type: str = None
|
||||
prompt: dict = None
|
||||
completion: dict = None
|
||||
|
||||
status: str = 'chain_started'
|
||||
completed: bool = False
|
||||
|
||||
started_at: float = None
|
||||
completed_at: float = None
|
||||
|
||||
agent_result: dict = None
|
||||
"""only when type is 'AgentExecutor'"""
|
||||
@@ -1,6 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DatasetQueryObj(BaseModel):
|
||||
dataset_id: str = None
|
||||
query: str = None
|
||||
@@ -1,8 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMMessage(BaseModel):
|
||||
prompt: str = ''
|
||||
prompt_tokens: int = 0
|
||||
completion: str = ''
|
||||
completion_tokens: int = 0
|
||||
@@ -1,17 +1,44 @@
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment
|
||||
from models.dataset import DocumentSegment, DatasetQuery
|
||||
from models.model import DatasetRetrieverResource
|
||||
|
||||
|
||||
class DatasetIndexToolCallbackHandler:
|
||||
"""Callback handler for dataset tool."""
|
||||
|
||||
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
|
||||
self.conversation_message_task = conversation_message_task
|
||||
def __init__(self, queue_manager: ApplicationQueueManager,
|
||||
app_id: str,
|
||||
message_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom) -> None:
|
||||
self._queue_manager = queue_manager
|
||||
self._app_id = app_id
|
||||
self._message_id = message_id
|
||||
self._user_id = user_id
|
||||
self._invoke_from = invoke_from
|
||||
|
||||
def on_query(self, query: str, dataset_id: str) -> None:
|
||||
"""
|
||||
Handle query.
|
||||
"""
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
content=query,
|
||||
source='app',
|
||||
source_app_id=self._app_id,
|
||||
created_by_role=('account'
|
||||
if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
|
||||
created_by=self._user_id
|
||||
)
|
||||
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
def on_tool_end(self, documents: List[Document]) -> None:
|
||||
"""Handle tool end."""
|
||||
@@ -30,4 +57,27 @@ class DatasetIndexToolCallbackHandler:
|
||||
|
||||
def return_retriever_resource_info(self, resource: List):
|
||||
"""Handle return_retriever_resource_info."""
|
||||
self.conversation_message_task.on_dataset_query_finish(resource)
|
||||
if resource and len(resource) > 0:
|
||||
for item in resource:
|
||||
dataset_retriever_resource = DatasetRetrieverResource(
|
||||
message_id=self._message_id,
|
||||
position=item.get('position'),
|
||||
dataset_id=item.get('dataset_id'),
|
||||
dataset_name=item.get('dataset_name'),
|
||||
document_id=item.get('document_id'),
|
||||
document_name=item.get('document_name'),
|
||||
data_source_type=item.get('data_source_type'),
|
||||
segment_id=item.get('segment_id'),
|
||||
score=item.get('score') if 'score' in item else None,
|
||||
hit_count=item.get('hit_count') if 'hit_count' in item else None,
|
||||
word_count=item.get('word_count') if 'word_count' in item else None,
|
||||
segment_position=item.get('segment_position') if 'segment_position' in item else None,
|
||||
index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
|
||||
content=item.get('content'),
|
||||
retriever_from=item.get('retriever_from'),
|
||||
created_by=self._user_id
|
||||
)
|
||||
db.session.add(dataset_retriever_resource)
|
||||
db.session.commit()
|
||||
|
||||
self._queue_manager.publish_retriever_resources(resource)
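# --- Illustrative example, not part of the diff --------------------------------------------
# Shape of one retriever-resource item as consumed by return_retriever_resource_info()
# above; only keys read by that method are listed and every value here is made up.
_example_resource_item = {
    "position": 1,
    "dataset_id": "example-dataset-id",
    "dataset_name": "FAQ",
    "document_id": "example-document-id",
    "document_name": "faq.md",
    "data_source_type": "upload_file",
    "segment_id": "example-segment-id",
    "score": 0.82,
    "hit_count": 3,
    "word_count": 120,
    "segment_position": 5,
    "index_node_hash": "example-hash",
    "content": "Refunds are processed within 5 business days.",
    "retriever_from": "app",
}
# ---------------------------------------------------------------------------------------------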
|
||||
|
||||
@@ -1,284 +0,0 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import LLMResult, BaseMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.callback_handler.entity.llm_message import LLMMessage
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
|
||||
ConversationTaskInterruptException
|
||||
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
|
||||
ImagePromptMessageFile
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.moderation.base import ModerationOutputsResult, ModerationAction
|
||||
from core.moderation.factory import ModerationFactory
|
||||
|
||||
|
||||
class ModerationRule(BaseModel):
|
||||
type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class LLMCallbackHandler(BaseCallbackHandler):
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, model_instance: BaseLLM,
|
||||
conversation_message_task: ConversationMessageTask):
|
||||
self.model_instance = model_instance
|
||||
self.llm_message = LLMMessage()
|
||||
self.start_at = None
|
||||
self.conversation_message_task = conversation_message_task
|
||||
|
||||
self.output_moderation_handler = None
|
||||
self.init_output_moderation()
|
||||
|
||||
def init_output_moderation(self):
|
||||
app_model_config = self.conversation_message_task.app_model_config
|
||||
sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
|
||||
|
||||
if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
|
||||
self.output_moderation_handler = OutputModerationHandler(
|
||||
tenant_id=self.conversation_message_task.tenant_id,
|
||||
app_id=self.conversation_message_task.app.id,
|
||||
rule=ModerationRule(
|
||||
type=sensitive_word_avoidance_dict.get("type"),
|
||||
config=sensitive_word_avoidance_dict.get("config")
|
||||
),
|
||||
on_message_replace_func=self.conversation_message_task.on_message_replace
|
||||
)
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
real_prompts = []
|
||||
for message in messages[0]:
|
||||
if message.type == 'human':
|
||||
role = 'user'
|
||||
elif message.type == 'ai':
|
||||
role = 'assistant'
|
||||
else:
|
||||
role = 'system'
|
||||
|
||||
real_prompts.append({
|
||||
"role": role,
|
||||
"text": message.content,
|
||||
"files": [{
|
||||
"type": file.type.value,
|
||||
"data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
|
||||
"detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
|
||||
} for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
|
||||
})
|
||||
|
||||
self.llm_message.prompt = real_prompts
|
||||
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
self.llm_message.prompt = [{
|
||||
"role": 'user',
|
||||
"text": prompts[0]
|
||||
}]
|
||||
|
||||
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
if self.output_moderation_handler:
|
||||
self.output_moderation_handler.stop_thread()
|
||||
|
||||
self.llm_message.completion = self.output_moderation_handler.moderation_completion(
|
||||
completion=response.generations[0][0].text,
|
||||
public_event=True if self.conversation_message_task.streaming else False
|
||||
)
|
||||
else:
|
||||
self.llm_message.completion = response.generations[0][0].text
|
||||
|
||||
if not self.conversation_message_task.streaming:
|
||||
self.conversation_message_task.append_message_text(self.llm_message.completion)
|
||||
|
||||
if response.llm_output and 'token_usage' in response.llm_output:
|
||||
if 'prompt_tokens' in response.llm_output['token_usage']:
|
||||
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
|
||||
if 'completion_tokens' in response.llm_output['token_usage']:
|
||||
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)])
|
||||
else:
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)])
|
||||
|
||||
self.conversation_message_task.save_message(self.llm_message)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
|
||||
# stop subscribe new token when output moderation should direct output
|
||||
ex = ConversationTaskInterruptException()
|
||||
self.on_llm_error(error=ex)
|
||||
raise ex
|
||||
|
||||
try:
|
||||
self.conversation_message_task.append_message_text(token)
|
||||
self.llm_message.completion += token
|
||||
|
||||
if self.output_moderation_handler:
|
||||
self.output_moderation_handler.append_new_token(token)
|
||||
except ConversationTaskStoppedException as ex:
|
||||
self.on_llm_error(error=ex)
|
||||
raise ex
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
if self.output_moderation_handler:
|
||||
self.output_moderation_handler.stop_thread()
|
||||
|
||||
if isinstance(error, ConversationTaskStoppedException):
|
||||
if self.conversation_message_task.streaming:
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)]
|
||||
)
|
||||
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
|
||||
if isinstance(error, ConversationTaskInterruptException):
|
||||
self.llm_message.completion = self.output_moderation_handler.get_final_output()
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)]
|
||||
)
|
||||
self.conversation_message_task.save_message(llm_message=self.llm_message)
|
||||
else:
|
||||
logging.debug("on_llm_error: %s", error)
|
||||
|
||||
|
||||
class OutputModerationHandler(BaseModel):
|
||||
DEFAULT_BUFFER_SIZE: int = 300
|
||||
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
|
||||
rule: ModerationRule
|
||||
on_message_replace_func: Any
|
||||
|
||||
thread: Optional[threading.Thread] = None
|
||||
thread_running: bool = True
|
||||
buffer: str = ''
|
||||
is_final_chunk: bool = False
|
||||
final_output: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def should_direct_output(self):
|
||||
return self.final_output is not None
|
||||
|
||||
def get_final_output(self):
|
||||
return self.final_output
|
||||
|
||||
def append_new_token(self, token: str):
|
||||
self.buffer += token
|
||||
|
||||
if not self.thread:
|
||||
self.thread = self.start_thread()
|
||||
|
||||
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
|
||||
self.buffer = completion
|
||||
self.is_final_chunk = True
|
||||
|
||||
result = self.moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
moderation_buffer=completion
|
||||
)
|
||||
|
||||
if not result or not result.flagged:
|
||||
return completion
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
final_output = result.preset_response
|
||||
else:
|
||||
final_output = result.text
|
||||
|
||||
if public_event:
|
||||
self.on_message_replace_func(final_output)
|
||||
|
||||
return final_output
|
||||
|
||||
def start_thread(self) -> threading.Thread:
|
||||
buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
|
||||
thread = threading.Thread(target=self.worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
|
||||
})
|
||||
|
||||
thread.start()
|
||||
|
||||
return thread
|
||||
|
||||
def stop_thread(self):
|
||||
if self.thread and self.thread.is_alive():
|
||||
self.thread_running = False
|
||||
|
||||
def worker(self, flask_app: Flask, buffer_size: int):
|
||||
with flask_app.app_context():
|
||||
current_length = 0
|
||||
while self.thread_running:
|
||||
moderation_buffer = self.buffer
|
||||
buffer_length = len(moderation_buffer)
|
||||
if not self.is_final_chunk:
|
||||
chunk_length = buffer_length - current_length
|
||||
if 0 <= chunk_length < buffer_size:
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
current_length = buffer_length
|
||||
|
||||
result = self.moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
moderation_buffer=moderation_buffer
|
||||
)
|
||||
|
||||
if not result or not result.flagged:
|
||||
continue
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
final_output = result.preset_response
|
||||
self.final_output = final_output
|
||||
else:
|
||||
final_output = result.text + self.buffer[len(moderation_buffer):]
|
||||
|
||||
# trigger replace event
|
||||
if self.thread_running:
|
||||
self.on_message_replace_func(final_output)
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
break
|
||||
|
||||
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
|
||||
try:
|
||||
moderation_factory = ModerationFactory(
|
||||
name=self.rule.type,
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
config=self.rule.config
|
||||
)
|
||||
|
||||
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error("Moderation Output error: %s", e)
|
||||
|
||||
return None
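# --- Illustrative usage sketch, not part of the diff ----------------------------------------
# How the handler above is driven during streaming: tokens are buffered with
# append_new_token() (which lazily starts the background moderation thread), the final text
# goes through moderation_completion(), and should_direct_output() signals that a preset
# response must replace the stream. All constructor values are made up, and this must run
# inside an active Flask app context, as it does in the callback handler that owns it.
_example_handler = OutputModerationHandler(
    tenant_id="example-tenant-id",
    app_id="example-app-id",
    rule=ModerationRule(type="keywords", config={"keywords": ["forbidden"]}),
    on_message_replace_func=print,   # stand-in for the real message-replace callback
)

for _token in ["Hello", ", ", "world"]:
    _example_handler.append_new_token(_token)
    if _example_handler.should_direct_output():
        break

_final_text = _example_handler.moderation_completion(completion="Hello, world")
_example_handler.stop_thread()
# ----------------------------------------------------------------------------------------------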
|
||||
@@ -1,76 +0,0 @@
import logging
import time

from typing import Any, Dict, Union

from langchain.callbacks.base import BaseCallbackHandler

from core.callback_handler.entity.chain_result import ChainResult
from core.conversation_message_task import ConversationMessageTask


class MainChainGatherCallbackHandler(BaseCallbackHandler):
    """Callback Handler that prints to std out."""
    raise_error: bool = True

    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
        """Initialize callback handler."""
        self._current_chain_result = None
        self._current_chain_message = None
        self.conversation_message_task = conversation_message_task
        self.agent_callback = None

    def clear_chain_results(self) -> None:
        self._current_chain_result = None
        self._current_chain_message = None
        if self.agent_callback:
            self.agent_callback.current_chain = None

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return True

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return True

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Print out that we are entering a chain."""
        if not self._current_chain_result:
            chain_type = serialized['id'][-1]
            if chain_type:
                self._current_chain_result = ChainResult(
                    type=chain_type,
                    prompt=inputs,
                    started_at=time.perf_counter()
                )
                self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
                if self.agent_callback:
                    self.agent_callback.current_chain = self._current_chain_message

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Print out that we finished a chain."""
        if self._current_chain_result and self._current_chain_result.status == 'chain_started':
            self._current_chain_result.status = 'chain_ended'
            self._current_chain_result.completion = outputs
            self._current_chain_result.completed = True
            self._current_chain_result.completed_at = time.perf_counter()

            self.conversation_message_task.on_chain_end(self._current_chain_message, self._current_chain_result)

            self.clear_chain_results()

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        logging.debug("Dataset tool on_chain_error: %s", error)
        self.clear_chain_results()
@@ -79,8 +79,11 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
        """Run on agent action."""
        tool = action.tool
        tool_input = action.tool_input
        action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
        thought = action.log[:action_name_position].strip() if action.log else ''
        try:
            action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
            thought = action.log[:action_name_position].strip() if action.log else ''
        except ValueError:
            thought = ''

        log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
        print_text("\n[on_agent_action]\n" + log + "\n", color='green')
@@ -5,15 +5,19 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import LLMResult, Generation
from langchain.schema.language_model import BaseLanguageModel

from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.entities.application_entities import ModelConfigEntity
from core.model_manager import ModelInstance
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.third_party.langchain.llms.fake import FakeLLM


class LLMChain(LCLLMChain):
    model_instance: BaseLLM
    model_config: ModelConfigEntity
    """The language model instance to use."""
    llm: BaseLanguageModel = FakeLLM(response="")
    parameters: Dict[str, Any] = {}
    agent_llm_callback: Optional[AgentLLMCallback] = None

    def generate(
        self,
@@ -23,14 +27,23 @@ class LLMChain(LCLLMChain):
        """Generate LLM result from inputs."""
        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
        messages = prompts[0].to_messages()
        prompt_messages = to_prompt_messages(messages)
        result = self.model_instance.run(
            messages=prompt_messages,
            stop=stop
        prompt_messages = lc_messages_to_prompt_messages(messages)

        model_instance = ModelInstance(
            provider_model_bundle=self.model_config.provider_model_bundle,
            model=self.model_config.model,
        )

        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            stream=False,
            stop=stop,
            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None,
            model_parameters=self.parameters
        )

        generations = [
            [Generation(text=result.content)]
            [Generation(text=result.message.content)]
        ]

        return LLMResult(generations=generations)
@@ -1,417 +0,0 @@
import concurrent
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, Union, Tuple

from flask import current_app, Flask
from requests.exceptions import ChunkedEncodingError

from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
    ConversationTaskInterruptException
from core.external_data_tool.factory import ExternalDataToolFactory
from core.file.file_obj import FileObj
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
                 streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
                 auto_generate_name: bool = True):
        """
        errors: ProviderTokenNotInitError
        """
        query = PromptTemplateParser.remove_template_variables(query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=app.tenant_id,
            model_config=app_model_config.model_dict,
            streaming=streaming
        )

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            files=files,
            streaming=streaming,
            model_instance=final_model_instance,
            auto_generate_name=auto_generate_name
        )

        prompt_message_files = [file.prompt_message_file for file in files]

        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            model_instance=final_model_instance,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs,
            files=prompt_message_files
        )

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
            tenant_id=app.tenant_id,
            app_model_config=app_model_config
        )

        try:
            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)

            try:
                # process sensitive_word_avoidance
                inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
            except ModerationException as e:
                cls.run_final_llm(
                    model_instance=final_model_instance,
                    mode=app.mode,
                    app_model_config=app_model_config,
                    query=query,
                    inputs=inputs,
                    files=prompt_message_files,
                    agent_execute_result=None,
                    conversation_message_task=conversation_message_task,
                    memory=memory,
                    fake_response=str(e)
                )
                return

            # fill in variable inputs from external data tools if exists
            external_data_tools = app_model_config.external_data_tools_list
            if external_data_tools:
                inputs = cls.fill_in_inputs_from_external_data_tools(
                    tenant_id=app.tenant_id,
                    app_id=app.id,
                    external_data_tools=external_data_tools,
                    inputs=inputs,
                    query=query
                )

            # get agent executor
            agent_executor = orchestrator_rule_parser.to_agent_executor(
                conversation_message_task=conversation_message_task,
                memory=memory,
                rest_tokens=rest_tokens_for_context_and_memory,
                chain_callback=chain_callback,
                tenant_id=app.tenant_id,
                retriever_from=retriever_from
            )

            query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)

            # run agent executor
            agent_execute_result = None
            if query_for_agent and agent_executor:
                should_use_agent = agent_executor.should_use_agent(query_for_agent)
                if should_use_agent:
                    agent_execute_result = agent_executor.run(query_for_agent)

            # When no extra pre prompt is specified,
            # the output of the agent can be used directly as the main output content without calling LLM again
            fake_response = None
            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
                                                              PlanningStrategy.REACT_ROUTER]:
                fake_response = agent_execute_result.output

            # run the final llm
            cls.run_final_llm(
                model_instance=final_model_instance,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                files=prompt_message_files,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory,
                fake_response=fake_response
            )
        except (ConversationTaskInterruptException, ConversationTaskStoppedException):
            return
        except ChunkedEncodingError as e:
            # Interrupt by LLM (like OpenAI), handle it.
            logging.warning(f'ChunkedEncodingError: {e}')
            conversation_message_task.end()
            return

    @classmethod
    def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
        if not app_model_config.sensitive_word_avoidance_dict['enabled']:
            return inputs, query

        type = app_model_config.sensitive_word_avoidance_dict['type']

        moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
        moderation_result = moderation.moderation_for_inputs(inputs, query)

        if not moderation_result.flagged:
            return inputs, query

        if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
            raise ModerationException(moderation_result.preset_response)
        elif moderation_result.action == ModerationAction.OVERRIDED:
            inputs = moderation_result.inputs
            query = moderation_result.query

        return inputs, query

    @classmethod
    def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
                                                inputs: dict, query: str) -> dict:
        """
        Fill in variable inputs from external data tools if exists.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        # Group tools by type and config
        grouped_tools = {}
        for tool in external_data_tools:
            if not tool.get("enabled"):
                continue

            tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
            grouped_tools.setdefault(tool_key, []).append(tool)

        results = {}
        with ThreadPoolExecutor() as executor:
            futures = {}
            for tool in external_data_tools:
                if not tool.get("enabled"):
                    continue

                future = executor.submit(
                    cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool,
                    inputs, query
                )

                futures[future] = tool

            for future in concurrent.futures.as_completed(futures):
                tool_variable, result = future.result()
                results[tool_variable] = result

        inputs.update(results)
        return inputs

    @classmethod
    def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
                                 inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
        with flask_app.app_context():
            tool_variable = external_data_tool.get("variable")
            tool_type = external_data_tool.get("type")
            tool_config = external_data_tool.get("config")

            external_data_tool_factory = ExternalDataToolFactory(
                name=tool_type,
                tenant_id=tenant_id,
                app_id=app_id,
                variable=tool_variable,
                config=tool_config
            )

            # query external data tool
            result = external_data_tool_factory.query(
                inputs=inputs,
                query=query
            )

            return tool_variable, result

    @classmethod
    def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
        if app.mode != 'completion':
            return query

        return inputs.get(app_model_config.dataset_query_variable, "")

    @classmethod
    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
                      inputs: dict,
                      files: List[PromptMessageFile],
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
                      fake_response: Optional[str]):
        prompt_transform = PromptTransform()

        # get llm prompt
        if app_model_config.prompt_type == 'simple':
            prompt_messages, stop_words = prompt_transform.get_prompt(
                app_mode=mode,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
                files=files,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=mode,
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
                files=files,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
            )

            model_config = app_model_config.model_dict
            completion_params = model_config.get("completion_params", {})
            stop_words = completion_params.get("stop", [])

        cls.recale_llm_max_tokens(
            model_instance=model_instance,
            prompt_messages=prompt_messages,
        )

        response = model_instance.run(
            messages=prompt_messages,
            stop=stop_words if stop_words else None,
            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
            fake_response=fake_response
        )

        return response

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
        # only for calc token in memory
        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=tenant_id,
            model_config=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            model_instance=memory_model_instance,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
                                 query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        prompt_transform = PromptTransform()

        # get prompt without memory and context
        if app_model_config.prompt_type == 'simple':
            prompt_messages, _ = prompt_transform.get_prompt(
                app_mode=mode,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
                files=files,
                context=None,
                memory=None,
                model_instance=model_instance
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=mode,
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
                files=files,
                context=None,
                memory=None,
                model_instance=model_instance
            )

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
                                     "or shrink the max token, or switch to a llm with a larger token limit size.")

        return rest_tokens

    @classmethod
    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
        # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)

        # update model instance max tokens
        model_kwargs = model_instance.get_model_kwargs()
        model_kwargs.max_tokens = max_tokens
        model_instance.set_model_kwargs(model_kwargs)
@@ -1,489 +0,0 @@
import json
import time
from typing import Optional, Union, List

from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
from core.file.file_obj import FileObj
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
    MessageChain, DatasetRetrieverResource, MessageFile


class ConversationMessageTask:
    def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
                 inputs: dict, query: str, files: List[FileObj], streaming: bool,
                 model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
                 auto_generate_name: bool = True):
        self.start_at = time.perf_counter()

        self.task_id = task_id

        self.app = app
        self.tenant_id = app.tenant_id
        self.app_model_config = app_model_config
        self.is_override = is_override

        self.user = user
        self.inputs = inputs
        self.query = query
        self.files = files
        self.streaming = streaming

        self.conversation = conversation
        self.is_new_conversation = False

        self.model_instance = model_instance

        self.message = None

        self.retriever_resource = None
        self.auto_generate_name = auto_generate_name

        self.model_dict = self.app_model_config.model_dict
        self.provider_name = self.model_dict.get('provider')
        self.model_name = self.model_dict.get('name')
        self.mode = app.mode

        self.init()

        self._pub_handler = PubHandler(
            user=self.user,
            task_id=self.task_id,
            message=self.message,
            conversation=self.conversation,
            chain_pub=False,  # disabled currently
            agent_thought_pub=True
        )

    def init(self):
        override_model_configs = None
        if self.is_override:
            override_model_configs = self.app_model_config.to_dict()

        introduction = ''
        system_instruction = ''
        system_instruction_tokens = 0
        if self.mode == 'chat':
            introduction = self.app_model_config.opening_statement
            if introduction:
                prompt_template = PromptTemplateParser(template=introduction)
                prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
                try:
                    introduction = prompt_template.format(prompt_inputs)
                except KeyError:
                    pass

            if self.app_model_config.pre_prompt:
                system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
                system_instruction = system_message.content
                model_instance = ModelFactory.get_text_generation_model(
                    tenant_id=self.tenant_id,
                    model_provider_name=self.provider_name,
                    model_name=self.model_name
                )
                system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))

        if not self.conversation:
            self.is_new_conversation = True
            self.conversation = Conversation(
                app_id=self.app.id,
                app_model_config_id=self.app_model_config.id,
                model_provider=self.provider_name,
                model_id=self.model_name,
                override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                mode=self.mode,
                name='New conversation',
                inputs=self.inputs,
                introduction=introduction,
                system_instruction=system_instruction,
                system_instruction_tokens=system_instruction_tokens,
                status='normal',
                from_source=('console' if isinstance(self.user, Account) else 'api'),
                from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
                from_account_id=(self.user.id if isinstance(self.user, Account) else None),
            )

            db.session.add(self.conversation)
            db.session.commit()

        self.message = Message(
            app_id=self.app.id,
            model_provider=self.provider_name,
            model_id=self.model_name,
            override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
            conversation_id=self.conversation.id,
            inputs=self.inputs,
            query=self.query,
            message="",
            message_tokens=0,
            message_unit_price=0,
            message_price_unit=0,
            answer="",
            answer_tokens=0,
            answer_unit_price=0,
            answer_price_unit=0,
            provider_response_latency=0,
            total_price=0,
            currency=self.model_instance.get_currency(),
            from_source=('console' if isinstance(self.user, Account) else 'api'),
            from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
            from_account_id=(self.user.id if isinstance(self.user, Account) else None),
            agent_based=self.app_model_config.agent_mode_dict.get('enabled'),
        )

        db.session.add(self.message)
        db.session.commit()

        for file in self.files:
            message_file = MessageFile(
                message_id=self.message.id,
                type=file.type.value,
                transfer_method=file.transfer_method.value,
                url=file.url,
                upload_file_id=file.upload_file_id,
                created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
                created_by=self.user.id
            )
            db.session.add(message_file)
            db.session.commit()

    def append_message_text(self, text: str):
        if text is not None:
            self._pub_handler.pub_text(text)

    def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
        message_tokens = llm_message.prompt_tokens
        answer_tokens = llm_message.completion_tokens

        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
        message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
        answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
        answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)

        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
        answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
        total_price = message_total_price + answer_total_price

        self.message.message = llm_message.prompt
        self.message.message_tokens = message_tokens
        self.message.message_unit_price = message_unit_price
        self.message.message_price_unit = message_price_unit
        self.message.answer = PromptTemplateParser.remove_template_variables(
            llm_message.completion.strip()) if llm_message.completion else ''
        self.message.answer_tokens = answer_tokens
        self.message.answer_unit_price = answer_unit_price
        self.message.answer_price_unit = answer_price_unit
        self.message.provider_response_latency = time.perf_counter() - self.start_at
        self.message.total_price = total_price

        db.session.commit()

        message_was_created.send(
            self.message,
            conversation=self.conversation,
            is_first_message=self.is_new_conversation,
            auto_generate_name=self.auto_generate_name
        )

        if not by_stopped:
            self.end()

    def init_chain(self, chain_result: ChainResult):
        message_chain = MessageChain(
            message_id=self.message.id,
            type=chain_result.type,
            input=json.dumps(chain_result.prompt),
            output=''
        )

        db.session.add(message_chain)
        db.session.commit()

        return message_chain

    def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
        message_chain.output = json.dumps(chain_result.completion)
        db.session.commit()

        self._pub_handler.pub_chain(message_chain)

    def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
        message_agent_thought = MessageAgentThought(
            message_id=self.message.id,
            message_chain_id=message_chain.id,
            position=agent_loop.position,
            thought=agent_loop.thought,
            tool=agent_loop.tool_name,
            tool_input=agent_loop.tool_input,
            message=agent_loop.prompt,
            message_price_unit=0,
            answer=agent_loop.completion,
            answer_price_unit=0,
            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
            created_by=self.user.id
        )

        db.session.add(message_agent_thought)
        db.session.commit()

        self._pub_handler.pub_agent_thought(message_agent_thought)

        return message_agent_thought

    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
                     agent_loop: AgentLoop):
        agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
        agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
        agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
        agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)

        loop_message_tokens = agent_loop.prompt_tokens
        loop_answer_tokens = agent_loop.completion_tokens

        loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
        loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
        loop_total_price = loop_message_total_price + loop_answer_total_price

        message_agent_thought.observation = agent_loop.tool_output
        message_agent_thought.tool_process_data = ''  # currently not support
        message_agent_thought.message_token = loop_message_tokens
        message_agent_thought.message_unit_price = agent_message_unit_price
        message_agent_thought.message_price_unit = agent_message_price_unit
        message_agent_thought.answer_token = loop_answer_tokens
        message_agent_thought.answer_unit_price = agent_answer_unit_price
        message_agent_thought.answer_price_unit = agent_answer_price_unit
        message_agent_thought.latency = agent_loop.latency
        message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
        message_agent_thought.total_price = loop_total_price
        message_agent_thought.currency = agent_model_instance.get_currency()
        db.session.commit()

    def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
        dataset_query = DatasetQuery(
            dataset_id=dataset_query_obj.dataset_id,
            content=dataset_query_obj.query,
            source='app',
            source_app_id=self.app.id,
            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
            created_by=self.user.id
        )

        db.session.add(dataset_query)
        db.session.commit()

    def on_dataset_query_finish(self, resource: List):
        if resource and len(resource) > 0:
            for item in resource:
                dataset_retriever_resource = DatasetRetrieverResource(
                    message_id=self.message.id,
                    position=item.get('position'),
                    dataset_id=item.get('dataset_id'),
                    dataset_name=item.get('dataset_name'),
                    document_id=item.get('document_id'),
                    document_name=item.get('document_name'),
                    data_source_type=item.get('data_source_type'),
                    segment_id=item.get('segment_id'),
                    score=item.get('score') if 'score' in item else None,
                    hit_count=item.get('hit_count') if 'hit_count' else None,
                    word_count=item.get('word_count') if 'word_count' in item else None,
                    segment_position=item.get('segment_position') if 'segment_position' in item else None,
                    index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
                    content=item.get('content'),
                    retriever_from=item.get('retriever_from'),
                    created_by=self.user.id
                )
                db.session.add(dataset_retriever_resource)
                db.session.commit()
            self.retriever_resource = resource

    def on_message_replace(self, text: str):
        if text is not None:
            self._pub_handler.pub_message_replace(text)

    def message_end(self):
        self._pub_handler.pub_message_end(self.retriever_resource)

    def end(self):
        self._pub_handler.pub_message_end(self.retriever_resource)
        self._pub_handler.pub_end()


class PubHandler:
    def __init__(self, user: Union[Account, EndUser], task_id: str,
                 message: Message, conversation: Conversation,
                 chain_pub: bool = False, agent_thought_pub: bool = False):
        self._channel = PubHandler.generate_channel_name(user, task_id)
        self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id)

        self._task_id = task_id
        self._message = message
        self._conversation = conversation
        self._chain_pub = chain_pub
        self._agent_thought_pub = agent_thought_pub

    @classmethod
    def generate_channel_name(cls, user: Union[Account, EndUser], task_id: str):
        if not user:
            raise ValueError("user is required")

        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
        return "generate_result:{}-{}".format(user_str, task_id)

    @classmethod
    def generate_stopped_cache_key(cls, user: Union[Account, EndUser], task_id: str):
        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
        return "generate_result_stopped:{}-{}".format(user_str, task_id)

    def pub_text(self, text: str):
        content = {
            'event': 'message',
            'data': {
                'task_id': self._task_id,
                'message_id': str(self._message.id),
                'text': text,
                'mode': self._conversation.mode,
                'conversation_id': str(self._conversation.id)
            }
        }

        redis_client.publish(self._channel, json.dumps(content))

        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_message_replace(self, text: str):
        content = {
            'event': 'message_replace',
            'data': {
                'task_id': self._task_id,
                'message_id': str(self._message.id),
                'text': text,
                'mode': self._conversation.mode,
                'conversation_id': str(self._conversation.id)
            }
        }

        redis_client.publish(self._channel, json.dumps(content))

        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_chain(self, message_chain: MessageChain):
        if self._chain_pub:
            content = {
                'event': 'chain',
                'data': {
                    'task_id': self._task_id,
                    'message_id': self._message.id,
                    'chain_id': message_chain.id,
                    'type': message_chain.type,
                    'input': json.loads(message_chain.input),
                    'output': json.loads(message_chain.output),
                    'mode': self._conversation.mode,
                    'conversation_id': self._conversation.id
                }
            }

            redis_client.publish(self._channel, json.dumps(content))

            if self._is_stopped():
                self.pub_end()
                raise ConversationTaskStoppedException()

    def pub_agent_thought(self, message_agent_thought: MessageAgentThought):
        if self._agent_thought_pub:
            content = {
                'event': 'agent_thought',
                'data': {
                    'id': message_agent_thought.id,
                    'task_id': self._task_id,
                    'message_id': self._message.id,
                    'chain_id': message_agent_thought.message_chain_id,
                    'position': message_agent_thought.position,
                    'thought': message_agent_thought.thought,
                    'tool': message_agent_thought.tool,
                    'tool_input': message_agent_thought.tool_input,
                    'mode': self._conversation.mode,
                    'conversation_id': self._conversation.id
                }
            }

            redis_client.publish(self._channel, json.dumps(content))

            if self._is_stopped():
                self.pub_end()
                raise ConversationTaskStoppedException()

    def pub_message_end(self, retriever_resource: List):
        content = {
            'event': 'message_end',
            'data': {
                'task_id': self._task_id,
                'message_id': self._message.id,
                'mode': self._conversation.mode,
                'conversation_id': self._conversation.id
            }
        }
        if retriever_resource:
            content['data']['retriever_resources'] = retriever_resource
        redis_client.publish(self._channel, json.dumps(content))

        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_end(self):
        content = {
            'event': 'end',
        }

        redis_client.publish(self._channel, json.dumps(content))

    @classmethod
    def pub_error(cls, user: Union[Account, EndUser], task_id: str, e):
        content = {
            'error': type(e).__name__,
            'description': e.description if getattr(e, 'description', None) is not None else str(e)
        }

        channel = cls.generate_channel_name(user, task_id)
        redis_client.publish(channel, json.dumps(content))

    def _is_stopped(self):
        return redis_client.get(self._stopped_cache_key) is not None

    @classmethod
    def ping(cls, user: Union[Account, EndUser], task_id: str):
        content = {
            'event': 'ping'
        }

        channel = cls.generate_channel_name(user, task_id)
        redis_client.publish(channel, json.dumps(content))

    @classmethod
    def stop(cls, user: Union[Account, EndUser], task_id: str):
        stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
        redis_client.setex(stopped_cache_key, 600, 1)


class ConversationTaskStoppedException(Exception):
    pass


class ConversationTaskInterruptException(Exception):
    pass
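Note: for orientation only, a minimal sketch of what a consumer of the Redis events published by PubHandler above might look like. The channel name format comes from generate_channel_name; the Redis connection details and the example account/task ids are assumptions, not part of this diff.

import json
import redis

# Illustrative subscriber only; host/port and the ids are placeholders.
# PubHandler publishes to "generate_result:account-<user_id>-<task_id>" channels.
r = redis.Redis(host='localhost', port=6379, db=0)
pubsub = r.pubsub()
pubsub.subscribe('generate_result:account-1-example-task-id')

for item in pubsub.listen():
    if item['type'] != 'message':
        continue
    event = json.loads(item['data'])
    if event.get('event') == 'message':
        # streamed answer chunk
        print(event['data']['text'], end='', flush=True)
    elif event.get('event') == 'end':
        break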
@@ -3,7 +3,8 @@ from pathlib import Path
from typing import List, Union, Optional

import requests
from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader
from flask import current_app
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document

from core.data_loader.loader.csv_loader import CSVLoader
@@ -11,6 +12,13 @@ from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
from extensions.ext_storage import storage
from models.model import UploadFile

@@ -49,14 +57,34 @@ class FileExtractor:
        input_file = Path(file_path)
        delimiter = '\n'
        file_extension = input_file.suffix.lower()
        if is_automatic:
            loader = UnstructuredFileLoader(
                file_path, strategy="hi_res", mode="elements"
            )
            # loader = UnstructuredAPIFileLoader(
            #     file_path=filenames[0],
            #     api_key="FAKE_API_KEY",
            # )
        etl_type = current_app.config['ETL_TYPE']
        unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
        if etl_type == 'Unstructured':
            if file_extension == '.xlsx':
                loader = ExcelLoader(file_path)
            elif file_extension == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif file_extension in ['.md', '.markdown']:
                loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url)
            elif file_extension in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif file_extension == '.docx':
                loader = Docx2txtLoader(file_path)
            elif file_extension == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            elif file_extension == '.msg':
                loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
            elif file_extension == '.eml':
                loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
            elif file_extension == '.ppt':
                loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
            elif file_extension == '.pptx':
                loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
            elif file_extension == '.xml':
                loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
            else:
                # txt
                loader = UnstructuredTextLoader(file_path, unstructured_api_url)
        else:
            if file_extension == '.xlsx':
                loader = ExcelLoader(file_path)
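Note: a minimal sketch of the extension-to-loader dispatch introduced in the hunk above, assuming only what the hunk shows (an ETL_TYPE switch and per-extension loaders). pick_loader is an illustrative helper, not a function in the repository, and only a few branches are reproduced.

from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader

# Hypothetical helper mirroring the FileExtractor dispatch; not part of this diff.
def pick_loader(file_path: str, etl_type: str, unstructured_api_url: str):
    file_extension = '.' + file_path.rsplit('.', 1)[-1].lower() if '.' in file_path else ''
    if etl_type == 'Unstructured':
        if file_extension in ['.md', '.markdown']:
            return UnstructuredMarkdownLoader(file_path, unstructured_api_url)
        if file_extension == '.msg':
            return UnstructuredMsgLoader(file_path, unstructured_api_url)
        # fall through to the plain-text loader, as in the hunk's else branch
        return UnstructuredTextLoader(file_path, unstructured_api_url)
    # non-Unstructured branch omitted in this sketch
    return None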
api/core/data_loader/loader/unstructured/unstructured_eml.py (new file)
@@ -0,0 +1,50 @@
import logging
import base64
from typing import List
from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredEmailLoader(BaseLoader):
    """Load msg files.
    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        api_url: str,
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url

    def load(self) -> List[Document]:
        from unstructured.partition.email import partition_email
        elements = partition_email(filename=self._file_path, api_url=self._api_url)

        # noinspection PyBroadException
        try:
            for element in elements:
                element_text = element.text.strip()

                padding_needed = 4 - len(element_text) % 4
                element_text += '=' * padding_needed

                element_decode = base64.b64decode(element_text)
                soup = BeautifulSoup(element_decode.decode('utf-8'), 'html.parser')
                element.text = soup.get_text()
        except Exception:
            pass

        from unstructured.chunking.title import chunk_by_title
        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
        documents = []
        for chunk in chunks:
            text = chunk.text.strip()
            documents.append(Document(page_content=text))
        return documents
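Note: a hedged usage sketch for the new email loader above; the file path and the Unstructured API URL below are placeholders, not values from this commit.

from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader

# Placeholder path and API URL; chunking behaviour (max_characters=2000) comes from load() above.
loader = UnstructuredEmailLoader('/tmp/example.eml', 'http://unstructured-api.local/general/v0/general')
documents = loader.load()
for doc in documents:
    print(doc.page_content[:80])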
@@ -0,0 +1,48 @@
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredMarkdownLoader(BaseLoader):
    """Load md files.


    Args:
        file_path: Path to the file to load.

        remove_hyperlinks: Whether to remove hyperlinks from the text.

        remove_images: Whether to remove images from the text.

        encoding: File encoding to use. If `None`, the file will be loaded
        with the default system encoding.

        autodetect_encoding: Whether to try to autodetect the file encoding
        if the specified encoding fails.
    """

    def __init__(
        self,
        file_path: str,
        api_url: str,
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url

    def load(self) -> List[Document]:
        from unstructured.partition.md import partition_md

        elements = partition_md(filename=self._file_path, api_url=self._api_url)
        from unstructured.chunking.title import chunk_by_title
        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
        documents = []
        for chunk in chunks:
            text = chunk.text.strip()
            documents.append(Document(page_content=text))

        return documents
api/core/data_loader/loader/unstructured/unstructured_msg.py (new file)
@@ -0,0 +1,40 @@
import logging
import re
from typing import Optional, List, Tuple, cast

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredMsgLoader(BaseLoader):
    """Load msg files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        api_url: str
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url

    def load(self) -> List[Document]:
        from unstructured.partition.msg import partition_msg

        elements = partition_msg(filename=self._file_path, api_url=self._api_url)
        from unstructured.chunking.title import chunk_by_title
        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
        documents = []
        for chunk in chunks:
            text = chunk.text.strip()
            documents.append(Document(page_content=text))

        return documents
api/core/data_loader/loader/unstructured/unstructured_ppt.py (new file)
@@ -0,0 +1,47 @@
import logging
import re
from typing import Optional, List, Tuple, cast

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredPPTLoader(BaseLoader):
    """Load msg files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        api_url: str
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url

    def load(self) -> List[Document]:
        from unstructured.partition.ppt import partition_ppt

        elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
        text_by_page = {}
        for element in elements:
            page = element.metadata.page_number
            text = element.text
            if page in text_by_page:
                text_by_page[page] += "\n" + text
            else:
                text_by_page[page] = text

        combined_texts = list(text_by_page.values())
        documents = []
        for combined_text in combined_texts:
            text = combined_text.strip()
            documents.append(Document(page_content=text))
        return documents
Some files were not shown because too many files have changed in this diff.