Mirror of https://github.com/langgenius/dify.git, synced 2026-01-20 14:04:17 +00:00
Compare commits
150 Commits
| SHA1 |
|---|
| 9f637ead38 |
| b521aafd26 |
| a84e15b8cc |
| 0c330fc020 |
| cfbb7bec58 |
| 3b357f51a6 |
| 09acf215f0 |
| 07dd8b94ed |
| ef308fd121 |
| fce64d760b |
| f0c9bb7c91 |
| d8672796b0 |
| 5929e84036 |
| 83063532a0 |
| 07279558a5 |
| 2166473852 |
| 44397e3062 |
| 883a0a0e6a |
| b5ed81b349 |
| 625b0afa52 |
| 2660fbaa20 |
| 9e37702d24 |
| bc11c6a7f2 |
| 10e9766fd3 |
| 6d24a2cb87 |
| 0a4dfaeaf9 |
| c0a4fd145c |
| 70f16e1a0b |
| cb27571e9f |
| 0518da5819 |
| d2797abdb4 |
| bf3ee660e0 |
| 68406b9906 |
| 6f7fd6613a |
| 6d5b386394 |
| 1ea18a2922 |
| f8f4b961a1 |
| 57565db531 |
| d844420c07 |
| 34634bddf1 |
| c97b7f6748 |
| 76cc19f525 |
| 5baaebb3fd |
| 9d072920da |
| 965ca36525 |
| b4988ce20c |
| d3d617239f |
| 6c3b34a61d |
| d76d1adb59 |
| cadc6b171e |
| fdae2a20ae |
| 45701a81e9 |
| 409e0c8e1c |
| 7076d41b29 |
| 5a6cb69951 |
| 11a75ee78a |
| b9b692d71d |
| d8f8afcbd0 |
| 8cb62ef31a |
| bb5d5fc683 |
| 2fc0dcc10a |
| 9fd55157d6 |
| 6c384dba71 |
| 9730297381 |
| 99e80a8ed0 |
| 26fef2d481 |
| c9e65f4221 |
| 20bd33fada |
| bd0af2e921 |
| 4ab66299d4 |
| 42227f93c0 |
| 89fcf4ea7c |
| 3322710dac |
| 404bf11d8c |
| 60a2ecbd17 |
| 828822243a |
| 2068ae215e |
| d4262ecceb |
| 8be7d8a635 |
| c038040e1b |
| 21450b8a51 |
| 5fc1bd026a |
| d60f1a5601 |
| da83f8403e |
| 4ff17af5de |
| a9d1b4e6d7 |
| 66612075d2 |
| b921c55677 |
| bdc5e9ceb0 |
| f2b2effc4b |
| 301e0496ff |
| 98660e1f97 |
| 6cf93379b3 |
| 8639abec97 |
| d5361b8d09 |
| 6bfdfab6f3 |
| bec998ab94 |
| 77636945fb |
| fd5c45ae10 |
| ad71386adf |
| 043517717e |
| 76c52300a2 |
| dda32c6880 |
| ac4bb5c35f |
| a96cae4f44 |
| 7cb75cb2e7 |
| 0940084fd2 |
| 95ad06c8c3 |
| 3c13c4f3ee |
| 2fe938b7da |
| 784da52ea6 |
| 78524a56ed |
| 6c614f0c1f |
| d42df4ed04 |
| 6d94126368 |
| e0f72d2791 |
| 3e51710fe6 |
| 7bfdca7a53 |
| 48d5628fd4 |
| c8fb619d37 |
| 57024614bd |
| a31b502668 |
| e58c3ac374 |
| 00f4e6ec44 |
| 6355e61eb8 |
| 27828f44b9 |
| 9525ca08b9 |
| 501caf0a69 |
| c17baef172 |
| 21ade71bad |
| 23e02d8eb0 |
| 86286e1ac8 |
| 7bbe12b2bd |
| e65a2a400d |
| 741079f317 |
| 0f5d4fd11b |
| 8eae206715 |
| 7434d44412 |
| 8394bbd47f |
| 14a2eeba0c |
| a18dde9b0d |
| 8438d820ad |
| e19ad023d2 |
| 0695f08f05 |
| 22ab4721e2 |
| 51f23c5dc2 |
| 1f48e3d44a |
| 0113627d7b |
| 0a5de0ff0b |
| 9c4bad8f1e |
26 .github/workflows/tool-tests.yaml (vendored, normal file)
@@ -0,0 +1,26 @@
```
name: Run Tool Pytest

on:
  pull_request:
    branches:
      - main

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
          cache: 'pip'
          cache-dependency-path: ./api/requirements.txt

      - name: Install dependencies
        run: pip install -r ./api/requirements.txt

      - name: Run pytest
        run: pytest ./api/tests/integration_tests/tools/test_all_provider.py
```
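This new workflow runs the tool integration tests on every pull request targeting `main`. For orientation, here is a rough sketch of the shape such a suite might take; the real `test_all_provider.py` is not part of this diff, and the provider names below are invented:

```
import pytest

# Hypothetical provider list; the real suite presumably iterates over Dify's
# registered tool providers instead of a hard-coded list.
PROVIDER_NAMES = ["google_search", "stable_diffusion", "wolframalpha"]

@pytest.mark.parametrize("provider_name", PROVIDER_NAMES)
def test_provider_name_is_well_formed(provider_name):
    # Stand-in assertions: a real test would load each provider and validate
    # its declared credentials/parameters schema.
    assert provider_name == provider_name.lower()
    assert " " not in provider_name
```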
@@ -91,6 +91,8 @@ To validate your set up, head over to [http://localhost:3000](http://localhost:3000)

If you are adding a model provider, [this guide](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md) is for you.

+If you are adding a tool provider to Agent or Workflow, [this guide](./api/core/tools/README.md) is for you.

To help you quickly navigate where your contribution fits, a brief, annotated outline of Dify's backend & frontend is as follows:

### Backend
@@ -1,57 +1,155 @@
# Contributing

So you're looking to contribute to Dify - that's awesome, we can't wait to see what you come up with. As a startup with limited staffing and funding, we have ambitious goals to design the most intuitive workflow for building and managing LLM applications. Any help from the community counts, truly.

-Thanks for your interest in [Dify](https://dify.ai) and for wanting to contribute! Before you begin, read the [Code of Conduct](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md) and check the [existing issues](https://github.com/langgenius/dify/issues).
-This document describes how to set up a development environment to build and test [Dify](https://dify.ai).

Given where we are, we need to be nimble and ship fast, but we also want to make sure that contributors like you get as smooth an experience as possible when contributing. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase and how we work with contributors, so you can quickly jump to the fun part.

-### Installing dependencies

This guide, like Dify itself, is a constant work in progress. We highly appreciate your understanding if at times it lags behind the actual project, and we welcome any feedback that helps us improve.

-You will need to install and configure the following dependencies on your machine to build [Dify](https://dify.ai):

In terms of licensing, please take a minute to read our short [License and Contributor Agreement](./license). The community also adheres to the [Code of Conduct](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md).

-- [Git](http://git-scm.com/)
-- [Docker](https://www.docker.com/)
-- [Docker Compose](https://docs.docker.com/compose/install/)
-- [Node.js v18.x (LTS)](http://nodejs.org)
-- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x

## Before You Jump In

-## Local development

[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [create](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into two types:

-To set up a working development environment, just fork the project's git repository, install the backend and frontend dependencies with the appropriate package managers, and create and run the docker-compose stack.

### Feature requests:

-### Fork the repository

* If you're opening a new feature request, please explain what the proposed feature achieves and include as much context as possible. [@perzeusss](https://github.com/perzeuss) has made a solid [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) that helps you draft out your needs. Feel free to give it a try.

-You need to fork the [Git repository](https://github.com/langgenius/dify).

* If you want to pick one up from the existing issues, simply drop a comment below it saying so.

-### Clone the repository

A team member working in the related direction will be looped in. If all looks good, they will give the go-ahead for you to start coding. We ask that you hold off working on the feature until then, so none of your work goes to waste should we propose changes.

-Clone your forked repository from GitHub:

Depending on whichever area the proposed feature falls under, you might talk to different team members. Here's an overview of the areas each of our team members are working on at the moment:

| Member | Scope |
| ------------------------------------------------------------ | ---------------------------------------------------- |
| [@yeuoly](https://github.com/Yeuoly) | Architecting Agents |
| [@jyong](https://github.com/JohnJyong) | RAG pipeline design |
| [@GarfieldDai](https://github.com/GarfieldDai) | Building workflow orchestrations |
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | Making our frontend a breeze to use |
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | Developer experience, points of contact for anything |
| [@takatost](https://github.com/takatost) | Overall product direction and architecture |

How we prioritize:

| Feature Type | Priority |
| ------------------------------------------------------------ | --------------- |
| High-Priority Features as being labeled by a team member | High Priority |
| Popular feature requests from our [community feedback board](https://feedback.dify.ai/) | Medium Priority |
| Non-core features and minor enhancements | Low Priority |
| Valuable but not immediate | Future-Feature |

### Anything else (e.g. bug reports, performance optimizations, typo corrections):

* Start coding right away.

How we prioritize:

| Issue Type | Priority |
| ------------------------------------------------------------ | --------------- |
| Bugs in core functions (cannot login, applications not working, security loopholes) | Critical |
| Non-critical bugs, performance boosts | Medium Priority |
| Minor fixes (typos, confusing but working UI) | Low Priority |

## Installing

Here are the steps to set up Dify for development:

### 1. Fork this repository

### 2. Clone the repo

Clone the forked repository from your terminal:

```
git clone git@github.com:<github_username>/dify.git
```

-### Install the backend

### 3. Verify dependencies

-To learn how to install the backend application, see the [Backend README](api/README.md).

Dify depends on the following tools and libraries:

-### Install the frontend

- [Docker](https://www.docker.com/)
- [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x

-To learn how to install the frontend application, see the [Frontend README](web/README.md).

### 4. Installations

-### Visit Dify in your browser

Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.

-Finally, you can now visit [http://localhost:3000](http://localhost:3000) to see [Dify](https://dify.ai) in your local environment.

Check the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for a list of common issues and steps to troubleshoot.

-## Creating a pull request

### 5. Visit Dify in your browser

-After making your changes, open a pull request (PR). Once you submit your pull request, others from the Dify team/community will review it with you.

To validate your set-up, open your browser and head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port). You should now see Dify up and running.

-Did you run into issues like merge conflicts, or don't know how to open a pull request? Check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests) on how to resolve merge conflicts and other issues. Once your PR has been merged, you will be proudly listed as a contributor in the [contributor chart](https://github.com/langgenius/dify/graphs/contributors).

## Developing

-## Community channels

If you are adding a model provider, [this guide](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md) is for you.

-Stuck? Got questions? Join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and we will help you out.

If you are adding a tool provider to Agent or Workflow, [this guide](./api/core/tools/README.md) is for you.

-### Multi-language support

To help you quickly navigate where your contribution fits, a brief, annotated outline of Dify's backend & frontend is as follows:

-To contribute translations, see the [Frontend multi-language translation README](web/i18n/README_CN.md).

### Backend

Dify's backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login.

```
[api/]
├── constants             // Constant settings used throughout the code base.
├── controllers           // API route definitions and request handling logic.
├── core                  // Core application orchestration, model integrations, and tools.
├── docker                // Docker & containerization related configurations.
├── events                // Event handling and processing.
├── extensions            // Extensions with 3rd party frameworks/platforms.
├── fields                // Field definitions for serialization/marshalling.
├── libs                  // Reusable libraries and helpers.
├── migrations            // Scripts for database migration.
├── models                // Database models & schema definitions.
├── services              // Specifies business logic.
├── storage               // Private key storage.
├── tasks                 // Handling of async tasks and background jobs.
└── tests
```
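For a feel of how those pieces fit together, here is a minimal, self-contained sketch of a Flask + SQLAlchemy + Celery setup; all names and configuration values below are illustrative, not Dify's actual wiring:

```
from celery import Celery
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///dev.db"  # assumed local dev DB
db = SQLAlchemy(app)
celery = Celery(app.name, broker="redis://localhost:6379/0")  # assumed Redis broker

class Document(db.Model):  # toy model, not Dify's schema
    id = db.Column(db.Integer, primary_key=True)
    indexing_status = db.Column(db.String(32), default="waiting")

@celery.task
def index_document_task(document_id: int):
    # Background job: mark a document as indexed once processing completes.
    with app.app_context():
        doc = db.session.get(Document, document_id)
        if doc is not None:
            doc.indexing_status = "completed"
            db.session.commit()
```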
### Frontend

The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typescript and uses [Tailwind CSS](https://tailwindcss.com/) for styling. [React-i18next](https://react.i18next.com/) is used for internationalization.

```
[web/]
├── app                   // layouts, pages, and components
│   ├── (commonLayout)    // common layout used throughout the app
│   ├── (shareLayout)     // layouts specifically shared across token-specific sessions
│   ├── activate          // activate page
│   ├── components        // shared by pages and layouts
│   ├── install           // install page
│   ├── signin            // signin page
│   └── styles            // globally shared styles
├── assets                // static assets
├── bin                   // scripts run at build step
├── config                // adjustable settings and options
├── context               // shared contexts used by different portions of the app
├── dictionaries          // language-specific translation files
├── docker                // container configurations
├── hooks                 // reusable hooks
├── i18n                  // internationalization configuration
├── models                // describes data models & shapes of API responses
├── public                // meta assets like favicon
├── service               // specifies shapes of API actions
├── test
├── types                 // descriptions of function params and return values
└── utils                 // shared utility functions
```

## Submitting your PR

At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing before they go into the `main` branch. In case you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).

And that's it! Once your PR is merged, you will be featured as a contributor in our [README](https://github.com/langgenius/dify/blob/main/README.md).

## Getting Help

If you ever get stuck or have a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/AhzKf7dNgk) for a quick chat.
@@ -1,55 +0,0 @@
# Contributing

Thank you for your interest in [Dify](https://dify.ai) and for wanting to contribute! Before you begin,
read the [Code of Conduct](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md) and
check the [existing issues](https://github.com/langgenius/langgenius-gateway/issues).
This document explains how to set up a development environment to build and test [Dify](https://dify.ai).

### Installing dependencies

You will need to install and configure the following dependencies on your machine to build [Dify](https://dify.ai):

- [Git](http://git-scm.com/)
- [Docker](https://www.docker.com/)
- [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x

## Local development

To set up a working development environment, fork the project's git repository, install the backend and frontend dependencies using the appropriate package managers, and create and run the docker-compose stack.

### Fork the repository

You need to fork the [repository](https://github.com/langgenius/dify).

### Clone the repository

Clone the repository you forked on GitHub:

```
git clone git@github.com:<github_username>/dify.git
```

### Install the backend

To learn how to install the backend application, see the [Backend README](api/README.md).

### Install the frontend

To learn how to install the frontend application, see the [Frontend README](web/README.md).

### Visit dify in your browser

You can now visit [http://localhost:3000](http://localhost:3000) to see [Dify](https://dify.ai) in your local environment.

## Creating a pull request

After making your changes, open a pull request (PR). Once you submit it, others from the Dify team/community will review it with you.

Ran into merge conflicts, or not sure how to open a pull request? Check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests) for how to resolve merge conflicts and other issues. Once your PR is merged, you will be proudly listed in the [contributor chart](https://github.com/langgenius/langgenius-gateway/graphs/contributors).

## Community channels

Stuck? Got questions? Join the [Discord Community server](https://discord.gg/j3XRWSPBf7) and we will help you out!
@@ -21,6 +21,11 @@
        <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

<p align="center">
  <a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
    Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
  </a>
</p>

**Dify** is an LLM application development platform that has helped build over **100,000** applications. It integrates BaaS and LLMOps, covering the essential tech stack for building generative AI-native applications, including a built-in RAG engine. Dify allows you to **deploy your own version of Assistants API and GPTs, based on any LLMs.**

@@ -55,7 +60,8 @@ You can try out [Dify.AI Cloud](https://dify.ai) now. It provides all the capabi

**3. RAG Engine**: Includes various RAG capabilities based on full-text indexing or vector database embeddings, allowing direct upload of PDFs, TXTs, and other text formats.

-**4. Agents**: A Function Calling based Agent framework that allows users to configure what they see is what they get. Dify includes basic plugin capabilities like Google Search.
+**4. AI Agent**: Based on Function Calling and ReAct, the Agent inference framework allows users to customize tools, what you see is what you get. Dify provides more than a dozen built-in tool calling capabilities, such as Google Search, DALL·E, Stable Diffusion, WolframAlpha, etc.

**5. Continuous Operations**: Monitor and analyze application logs and performance, continuously improving Prompts, datasets, or models using production data.
@@ -21,6 +21,12 @@
        <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

<p align="center">
  <a href="https://mp.weixin.qq.com/s/TnyfIuH-tPi9o1KNjwVArw" target="_blank">
    Dify Releases AI Agent Capability: Build GPTs and Assistants with Different Large Language Models
  </a>
</p>

Dify is an LLM application development platform; more than 100,000 applications have already been built on Dify.AI. It integrates the concepts of Backend as a Service and LLMOps, covering the core tech stack required to build generative-AI-native applications, including a built-in RAG engine. With Dify, you can self-deploy capabilities similar to Assistants API and GPTs based on any model.

![](./images/demo.)

@@ -53,7 +59,7 @@ Dify is model-neutral; compared with hard-coded development libraries such as LangChain, Dify is a

**3. RAG Engine**: Includes various RAG capabilities based on full-text indexing or vector database embeddings, allowing direct upload of PDFs, TXTs, and other text formats.

-**4. Agent**: A Function-Calling-based Agent framework that lets users customize the configuration, what you see is what you get. Dify provides basic plugin capabilities such as Google Search.
+**4. AI Agent**: An Agent inference framework based on Function Calling and ReAct that lets users customize tools, what you see is what you get. Dify provides more than a dozen built-in tool-calling capabilities, such as Google Search, DALL·E, Stable Diffusion, WolframAlpha, etc.

**5. Continuous Operations**: Monitor and analyze application logs and performance, continuously improving Prompts, datasets, or models using production data.
@@ -21,6 +21,12 @@
        <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

<p align="center">
  <a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
    Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
  </a>
</p>

**Dify** is a development platform for large language model (LLM) applications that has already seen more than **100,000** applications built on Dify.AI. It integrates the concepts of Backend as a Service and LLMOps, covering the essential tech stack required for building generative-AI-native applications, including a built-in RAG engine. With Dify, **you can self-deploy capabilities similar to Assistants API and GPTs based on any LLM.**

![](./images/demo.)

@@ -52,7 +58,7 @@ Dify is characterized by its model neutrality and is a complete tech stack

**3. RAG Engine**: Includes various RAG capabilities based on full-text indexing or vector database embeddings, allowing direct upload of PDFs, TXTs, and other text formats.

-**4. Agents**: A Function-Calling-based Agent framework that lets users configure what they see is what they get. Dify includes basic plugin capabilities such as Google Search.
+**4. AI Agent**: Based on Function Calling and ReAct, the Agent inference framework lets users customize tools, what you see is what you get. Dify provides more than a dozen built-in tool-calling capabilities, such as Google Search, DALL·E, Stable Diffusion, WolframAlpha, etc.

**5. Continuous Operations**: Monitor and analyze application logs and performance, continuously improving Prompts, datasets, or models using production data.
@@ -21,6 +21,13 @@
        <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

<p align="center">
  <a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
    Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
  </a>
</p>


**Dify** is an LLM application development platform that has already seen more than **100,000** applications built on Dify.AI. It integrates the concepts of Backend as a Service and LLMOps, covering the core tech stack required for building generative-AI-native applications, including a built-in RAG engine. With Dify, **you can self-deploy capabilities similar to Assistants API and GPTs based on any LLM.**

![](./images/demo.)

@@ -52,7 +59,7 @@ Dify presents model neutrality and is a complete tech stack

**3\. RAG Engine**: Includes various RAG capabilities based on full-text indexing or vector database embeddings, allowing direct upload of PDFs, TXTs, and other text formats.

-**4\. Agents**: A Function-Calling-based agent framework that lets users configure what they see is what they get. Dify includes basic plug-in capabilities like Google Search.
+**4\. AI Agent**: Based on Function Calling and ReAct, the Agent inference framework lets users customize tools, what you see is what you get. Dify offers more than a dozen built-in tool-calling capabilities, such as Google Search, DALL·E, Stable Diffusion, WolframAlpha, etc.

**5\. Continuous Operations**: Monitor and analyze application logs and performance, and continuously improve prompts, datasets, or models using production data.
@@ -21,6 +21,13 @@
        <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

<p align="center">
  <a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
    Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
  </a>
</p>


"Dify is an LLM application development platform with more than 100,000 applications already built on Dify.AI. It integrates the concepts of Backend as a Service and LLMOps, covering the core tech stack for building generative-AI-native applications, including a built-in RAG engine. With Dify, you can self-deploy Assistants-API-like and GPTs-like capabilities based on any LLM."

Please note that translating complex technical terms can sometimes result in slight variations in meaning due to differences in language nuances.

@@ -54,7 +61,7 @@ Dify is model-neutral and, compared to hard-coded development libraries like LangChain

**3\. RAG Engine**: Includes various RAG capabilities based on full-text indexing or vector database embeddings, allowing direct upload of PDFs, TXTs, and other text formats.

-**4\. Agents**: A Function-Calling-based agent framework where users can configure "what they see is what they get". Dify includes basic plugin capabilities such as Google Search.
+**4. AI Agent**: An Agent inference framework based on Function Calling and ReAct lets users customize tools. Dify provides more than a dozen built-in tool-calling capabilities, such as Google Search, DALL·E, Stable Diffusion, and WolframAlpha.

**5\. Continuous Operations**: Monitor and analyze application logs and performance, and continuously improve prompts, datasets, or models using production data.
@@ -52,7 +52,7 @@ Dify Daq rIn neutrality 'ej Hoch, LangChain tInHar HubwI'. maH Daqbe'law' Qawqar

**3. RAG Engine**: RAG vaD tIqpu' lo'taH indexing qor neH vector database wa' embeddings wIj, PDFs, TXTs, 'ej ghojmoHmoH HIq qorlIj je upload.

-**4. jenSuvpu'**: jenbe' SuDqang naQ moDwu' jenSuvpu' porgh cha'logh choHvam. Dify Google Search Hur vItlhutlh plugin choH.
+**4. AI Agent**: Function Calling 'ej ReAct Daq Hurmey, Agent inference framework Hoch users customize tools, vaj 'oH QaQ. Dify Hoch loS ghaH 'ej wa'vatlh built-in tool calling capabilities, Google Search, DALL·E, Stable Diffusion, WolframAlpha, 'ej.

**5. QaS muDHa'wI'**: cha'logh wa' pIq mI' logs 'ej quv yIn, vItlhutlh tIq 'e'wIj lo'taHmoHmoH Prompts, vItlhutlh, Hurmey ghaH production data jatlh.
@@ -1,18 +1,20 @@
```
-# packages install stage
-FROM python:3.10-slim AS base
+# base image
+FROM python:3.10-slim-bookworm AS base

LABEL maintainer="takatost@gmail.com"

+# install packages
+FROM base as packages

RUN apt-get update \
-    && apt-get install -y --no-install-recommends gcc g++ python3-dev libc-dev libffi-dev
+    && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev

COPY requirements.txt /requirements.txt

RUN pip install --prefix=/pkg -r requirements.txt

-# build stage
-FROM python:3.10-slim AS builder
+# production stage
+FROM base AS production

ENV FLASK_APP app.py
ENV EDITION SELF_HOSTED
```

@@ -30,11 +32,11 @@ ENV TZ UTC
```
WORKDIR /app/api

RUN apt-get update \
-    && apt-get install -y --no-install-recommends bash curl wget vim nodejs \
+    && apt-get install -y --no-install-recommends curl wget vim nodejs ffmpeg libgmp-dev libmpfr-dev libmpc-dev \
    && apt-get autoremove \
    && rm -rf /var/lib/apt/lists/*

-COPY --from=base /pkg /usr/local
+COPY --from=packages /pkg /usr/local
COPY . /app/api/

COPY docker/entrypoint.sh /entrypoint.sh
```
@@ -30,7 +30,7 @@ from flask import Flask, Response, request
```
from flask_cors import CORS
from libs.passport import PassportService
# DO NOT REMOVE BELOW
-from models import account, dataset, model, source, task, tool, web
+from models import account, dataset, model, source, task, tool, tools, web
from services.account_service import AccountService

# DO NOT REMOVE ABOVE
```

@@ -124,6 +124,7 @@ def load_user_from_request(request_from_flask_login):
```
    else:
        return None


@login_manager.unauthorized_handler
def unauthorized_handler():
+    """Handle unauthorized requests."""
```
651 api/commands.py
@@ -1,31 +1,20 @@
```
-import base64
import datetime
-import json
-import math
-import random
+import secrets
-import string
-import threading
-import time
-import uuid

import click
-import qdrant_client
-from core.embedding.cached_embedding import CacheEmbedding
-from core.index.index import IndexBuilder
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
-from flask import Flask, current_app
+from flask import current_app
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair
-from models.account import InvitationCode, Tenant, TenantAccountJoin
-from models.dataset import Dataset, DatasetCollectionBinding, DatasetQuery, Document
-from models.model import Account, App, AppModelConfig, Message, MessageAnnotation
-from models.provider import Provider, ProviderModel, ProviderQuotaType, ProviderType
-from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
-from tqdm import tqdm
+from models.account import Tenant
+from models.dataset import Dataset
+from models.model import Account
+from models.provider import Provider, ProviderModel
from werkzeug.exceptions import NotFound
```
@@ -34,15 +23,22 @@ from werkzeug.exceptions import NotFound
```
@click.option('--new-password', prompt=True, help='the new password.')
@click.option('--password-confirm', prompt=True, help='the new password confirm.')
def reset_password(email, new_password, password_confirm):
+    """
+    Reset password of owner account
+    Only available in SELF_HOSTED mode
+    """
    if str(new_password).strip() != str(password_confirm).strip():
        click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
        return

    account = db.session.query(Account). \
        filter(Account.email == email). \
        one_or_none()

    if not account:
        click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
        return

    try:
        valid_password(new_password)
    except:
```

@@ -68,15 +64,22 @@ def reset_password(email, new_password, password_confirm):
```
@click.option('--new-email', prompt=True, help='the new email.')
@click.option('--email-confirm', prompt=True, help='the new email confirm.')
def reset_email(email, new_email, email_confirm):
+    """
+    Replace account email
+    :return:
+    """
    if str(new_email).strip() != str(email_confirm).strip():
        click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
        return

    account = db.session.query(Account). \
        filter(Account.email == email). \
        one_or_none()

    if not account:
        click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
        return

    try:
        email_validate(new_email)
    except:
```

@@ -96,6 +99,11 @@ def reset_email(email, new_email, email_confirm):
```
@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
                                              ' this operation cannot be rolled back!', fg='red'))
def reset_encrypt_key_pair():
+    """
+    Reset the encrypted key pair of workspace for encrypt LLM credentials.
+    After the reset, all LLM credentials will become invalid, requiring re-entry.
+    Only support SELF_HOSTED mode.
+    """
    if current_app.config['EDITION'] != 'SELF_HOSTED':
        click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
        return
```
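All three commands follow the same click pattern: prompt for values, validate, then act. A compressed, standalone sketch of that pattern (with a toy validator standing in for `libs.password.valid_password`):

```
import click

def valid_password(pw: str) -> None:
    # Toy stand-in for libs.password.valid_password.
    if len(pw) < 8:
        raise ValueError('password must be at least 8 characters')

@click.command('reset-password-demo')
@click.option('--new-password', prompt=True, hide_input=True)
@click.option('--password-confirm', prompt=True, hide_input=True)
def reset_password_demo(new_password, password_confirm):
    if new_password.strip() != password_confirm.strip():
        click.echo(click.style('The two passwords do not match.', fg='red'))
        return
    try:
        valid_password(new_password)
    except ValueError as e:
        click.echo(click.style(str(e), fg='red'))
        return
    click.echo(click.style('Password accepted.', fg='green'))

if __name__ == '__main__':
    reset_password_demo()
```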
@@ -115,201 +123,11 @@ def reset_encrypt_key_pair():
```
        'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))


@click.command('generate-invitation-codes', help='Generate invitation codes.')
@click.option('--batch', help='The batch of invitation codes.')
@click.option('--count', prompt=True, help='Invitation codes count.')
def generate_invitation_codes(batch, count):
    if not batch:
        now = datetime.datetime.now()
        batch = now.strftime('%Y%m%d%H%M%S')

    if not count or int(count) <= 0:
        click.echo(click.style('sorry. the count must be greater than 0.', fg='red'))
        return

    count = int(count)

    click.echo('Start generate {} invitation codes for batch {}.'.format(count, batch))

    codes = ''
    for i in range(count):
        code = generate_invitation_code()
        invitation_code = InvitationCode(
            code=code,
            batch=batch
        )
        db.session.add(invitation_code)
        click.echo(code)

        codes += code + "\n"
    db.session.commit()

    filename = 'storage/invitation-codes-{}.txt'.format(batch)

    with open(filename, 'w') as f:
        f.write(codes)

    click.echo(click.style(
        'Congratulations! Generated {} invitation codes for batch {} and saved to the file \'{}\''.format(count, batch, filename),
        fg='green'))


def generate_invitation_code():
    code = generate_upper_string()
    while db.session.query(InvitationCode).filter(InvitationCode.code == code).count() > 0:
        code = generate_upper_string()

    return code


def generate_upper_string():
    letters_digits = string.ascii_uppercase + string.digits
    result = ""
    for i in range(8):
        result += random.choice(letters_digits)

    return result


@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
def recreate_all_dataset_indexes():
    click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
    recreate_count = 0

    page = 1
    while True:
        try:
            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
        except NotFound:
            break

        page += 1
        for dataset in datasets:
            try:
                click.echo('Recreating dataset index: {}'.format(dataset.id))
                index = IndexBuilder.get_index(dataset, 'high_quality')
                if index and index._is_origin():
                    index.recreate_dataset(dataset)
                    recreate_count += 1
                else:
                    click.echo('passed.')
            except Exception as e:
                click.echo(
                    click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
                continue

    click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))


@click.command('clean-unused-dataset-indexes', help='Clean unused dataset indexes.')
def clean_unused_dataset_indexes():
    click.echo(click.style('Start clean unused dataset indexes.', fg='green'))
    clean_days = int(current_app.config.get('CLEAN_DAY_SETTING'))
    start_at = time.perf_counter()
    thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
    page = 1
    while True:
        try:
            datasets = db.session.query(Dataset).filter(Dataset.created_at < thirty_days_ago) \
                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
        except NotFound:
            break
        page += 1
        for dataset in datasets:
            dataset_query = db.session.query(DatasetQuery).filter(
                DatasetQuery.created_at > thirty_days_ago,
                DatasetQuery.dataset_id == dataset.id
            ).all()
            if not dataset_query or len(dataset_query) == 0:
                documents = db.session.query(Document).filter(
                    Document.dataset_id == dataset.id,
                    Document.indexing_status == 'completed',
                    Document.enabled == True,
                    Document.archived == False,
                    Document.updated_at > thirty_days_ago
                ).all()
                if not documents or len(documents) == 0:
                    try:
                        # remove index
                        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
                        kw_index = IndexBuilder.get_index(dataset, 'economy')
                        # delete from vector index
                        if vector_index:
                            if dataset.collection_binding_id:
                                vector_index.delete_by_group_id(dataset.id)
                            else:
                                vector_index.delete()
                        kw_index.delete()
                        # update document
                        update_params = {
                            Document.enabled: False
                        }

                        Document.query.filter_by(dataset_id=dataset.id).update(update_params)
                        db.session.commit()
                        click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id),
                                               fg='green'))
                    except Exception as e:
                        click.echo(
                            click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                                        fg='red'))
    end_at = time.perf_counter()
    click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green'))


@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
def sync_anthropic_hosted_providers():
    if not hosted_model_providers.anthropic:
        click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
        return

    click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
    count = 0

    new_quota_limit = hosted_model_providers.anthropic.quota_limit

    page = 1
    while True:
        try:
            providers = db.session.query(Provider).filter(
                Provider.provider_name == 'anthropic',
                Provider.provider_type == ProviderType.SYSTEM.value,
                Provider.quota_type == ProviderQuotaType.TRIAL.value,
                Provider.quota_limit != new_quota_limit
            ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
        except NotFound:
            break

        page += 1
        for provider in providers:
            try:
                click.echo('Syncing tenant anthropic hosted provider: {}, origin: limit {}, used {}'
                           .format(provider.tenant_id, provider.quota_limit, provider.quota_used))
                original_quota_limit = provider.quota_limit
                division = math.ceil(new_quota_limit / 1000)

                provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
                    else original_quota_limit * division
                provider.quota_used = division * provider.quota_used
                db.session.commit()

                count += 1
            except Exception as e:
                click.echo(click.style(
                    'Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)),
                    fg='red'))
                continue

    click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))


@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
def create_qdrant_indexes():
    """
    Migrate other vector database data to Qdrant.
    """
    click.echo(click.style('Start create qdrant indexes.', fg='green'))
    create_count = 0
```
@@ -338,26 +156,7 @@ def create_qdrant_indexes():
```
                )
            except Exception:
+                try:
+                    embedding_model = model_manager.get_default_model_instance(
+                        tenant_id=dataset.tenant_id,
+                        model_type=ModelType.TEXT_EMBEDDING,
+                    )
+                    dataset.embedding_model = embedding_model.model
+                    dataset.embedding_model_provider = embedding_model.provider
+                except Exception:
-                provider = Provider(
-                    id='provider_id',
-                    tenant_id=dataset.tenant_id,
-                    provider_name='openai',
-                    provider_type=ProviderType.SYSTEM.value,
-                    encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
-                    is_valid=True,
-                )
-                model_provider = OpenAIProvider(provider=provider)
-                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
-                                                  model_provider=model_provider)
                    continue
            embeddings = CacheEmbedding(embedding_model)

            from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
```
@@ -392,402 +191,8 @@ def create_qdrant_indexes():
```
    click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))


@click.command('update-qdrant-indexes', help='Update qdrant indexes.')
def update_qdrant_indexes():
    click.echo(click.style('Start Update qdrant indexes.', fg='green'))
    create_count = 0

    page = 1
    while True:
        try:
            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
        except NotFound:
            break

        page += 1
        for dataset in datasets:
            if dataset.index_struct_dict:
                if dataset.index_struct_dict['type'] != 'qdrant':
                    try:
                        click.echo('Update dataset qdrant index: {}'.format(dataset.id))
                        try:
                            embedding_model = ModelFactory.get_embedding_model(
                                tenant_id=dataset.tenant_id,
                                model_provider_name=dataset.embedding_model_provider,
                                model_name=dataset.embedding_model
                            )
                        except Exception:
                            provider = Provider(
                                id='provider_id',
                                tenant_id=dataset.tenant_id,
                                provider_name='openai',
                                provider_type=ProviderType.CUSTOM.value,
                                encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
                                is_valid=True,
                            )
                            model_provider = OpenAIProvider(provider=provider)
                            embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
                                                              model_provider=model_provider)
                        embeddings = CacheEmbedding(embedding_model)

                        from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex

                        index = QdrantVectorIndex(
                            dataset=dataset,
                            config=QdrantConfig(
                                endpoint=current_app.config.get('QDRANT_URL'),
                                api_key=current_app.config.get('QDRANT_API_KEY'),
                                root_path=current_app.root_path
                            ),
                            embeddings=embeddings
                        )
                        if index:
                            index.update_qdrant_dataset(dataset)
                            create_count += 1
                        else:
                            click.echo('passed.')
                    except Exception as e:
                        click.echo(
                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                                        fg='red'))
                        continue

    click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
```
```
@click.command('normalization-collections', help='restore all collections in one')
def normalization_collections():
    click.echo(click.style('Start normalization collections.', fg='green'))
    normalization_count = []
    page = 1
    while True:
        try:
            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
        except NotFound:
            break
        datasets_result = datasets.items
        page += 1
        for i in range(0, len(datasets_result), 5):
            threads = []
            sub_datasets = datasets_result[i:i + 5]
            for dataset in sub_datasets:
                document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
                    'flask_app': current_app._get_current_object(),
                    'dataset': dataset,
                    'normalization_count': normalization_count
                })
                threads.append(document_format_thread)
                document_format_thread.start()
            for thread in threads:
                thread.join()

    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
```
```
@click.command('add-qdrant-full-text-index', help='add qdrant full text index')
def add_qdrant_full_text_index():
    click.echo(click.style('Start add full text index.', fg='green'))
    binds = db.session.query(DatasetCollectionBinding).all()
    if binds and current_app.config['VECTOR_STORE'] == 'qdrant':
        qdrant_url = current_app.config['QDRANT_URL']
        qdrant_api_key = current_app.config['QDRANT_API_KEY']
        client = qdrant_client.QdrantClient(
            qdrant_url,
            api_key=qdrant_api_key,  # For Qdrant Cloud, None for local instance
        )
        for bind in binds:
            try:
                text_index_params = TextIndexParams(
                    type=TextIndexType.TEXT,
                    tokenizer=TokenizerType.MULTILINGUAL,
                    min_token_len=2,
                    max_token_len=20,
                    lowercase=True
                )
                client.create_payload_index(bind.collection_name, 'page_content',
                                            field_schema=text_index_params)
            except Exception as e:
                click.echo(
                    click.style('Create full text index error: {} {}'.format(e.__class__.__name__, str(e)),
                                fg='red'))
            click.echo(
                click.style(
                    'Congratulations! add collection {} full text index successful.'.format(bind.collection_name),
                    fg='green'))
```
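The `create_payload_index` call above is a real qdrant-client API; here is a standalone snippet exercising the same full-text index parameters against a throwaway in-memory instance (collection name and vector size are arbitrary):

```
from qdrant_client import QdrantClient
from qdrant_client.http.models import (Distance, TextIndexParams, TextIndexType,
                                       TokenizerType, VectorParams)

client = QdrantClient(":memory:")  # local throwaway instance for the demo
client.create_collection(
    collection_name="demo_collection",
    vectors_config=VectorParams(size=4, distance=Distance.COSINE),
)

# Same index parameters as the command above.
text_index_params = TextIndexParams(
    type=TextIndexType.TEXT,
    tokenizer=TokenizerType.MULTILINGUAL,
    min_token_len=2,
    max_token_len=20,
    lowercase=True,
)
client.create_payload_index("demo_collection", "page_content",
                            field_schema=text_index_params)
```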
```
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
    with flask_app.app_context():
        try:
            click.echo('restore dataset index: {}'.format(dataset.id))
            try:
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=dataset.tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )
            except Exception:
                provider = Provider(
                    id='provider_id',
                    tenant_id=dataset.tenant_id,
                    provider_name='openai',
                    provider_type=ProviderType.CUSTOM.value,
                    encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
                    is_valid=True,
                )
                model_provider = OpenAIProvider(provider=provider)
                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
                                                  model_provider=model_provider)
            embeddings = CacheEmbedding(embedding_model)
            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
                filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
                       DatasetCollectionBinding.model_name == embedding_model.name). \
                order_by(DatasetCollectionBinding.created_at). \
                first()

            if not dataset_collection_binding:
                dataset_collection_binding = DatasetCollectionBinding(
                    provider_name=embedding_model.model_provider.provider_name,
                    model_name=embedding_model.name,
                    collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
                )
                db.session.add(dataset_collection_binding)
                db.session.commit()

            from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex

            index = QdrantVectorIndex(
                dataset=dataset,
                config=QdrantConfig(
                    endpoint=current_app.config.get('QDRANT_URL'),
                    api_key=current_app.config.get('QDRANT_API_KEY'),
                    root_path=current_app.root_path
                ),
                embeddings=embeddings
            )
            if index:
                # index.delete_by_group_id(dataset.id)
                index.restore_dataset_in_one(dataset, dataset_collection_binding)
            else:
                click.echo('passed.')
            normalization_count.append(1)
        except Exception as e:
            click.echo(
                click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                            fg='red'))
```
```
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def update_app_model_configs(batch_size):
    pre_prompt_template = '{{default_input}}'
    user_input_form_template = {
        "en-US": [
            {
                "paragraph": {
                    "label": "Query",
                    "variable": "default_input",
                    "required": False,
                    "default": ""
                }
            }
        ],
        "zh-Hans": [
            {
                "paragraph": {
                    "label": "查询内容",
                    "variable": "default_input",
                    "required": False,
                    "default": ""
                }
            }
        ]
    }

    click.secho("Start migrate old data that the text generator can support paragraph variable.", fg='green')

    total_records = db.session.query(AppModelConfig) \
        .join(App, App.app_model_config_id == AppModelConfig.id) \
        .filter(App.mode == 'completion') \
        .count()

    if total_records == 0:
        click.secho("No data to migrate.", fg='green')
        return

    num_batches = (total_records + batch_size - 1) // batch_size

    with tqdm(total=total_records, desc="Migrating Data") as pbar:
        for i in range(num_batches):
            offset = i * batch_size
            limit = min(batch_size, total_records - offset)

            click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')

            data_batch = db.session.query(AppModelConfig) \
                .join(App, App.app_model_config_id == AppModelConfig.id) \
                .filter(App.mode == 'completion') \
                .order_by(App.created_at) \
                .offset(offset).limit(limit).all()

            if not data_batch:
                click.secho("No more data to migrate.", fg='green')
                break

            try:
                click.secho(f"Migrating {len(data_batch)} records...", fg='green')
                for data in data_batch:
                    # click.secho(f"Migrating data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')

                    if data.pre_prompt is None:
                        data.pre_prompt = pre_prompt_template
                    else:
                        if pre_prompt_template in data.pre_prompt:
                            continue
                        data.pre_prompt += pre_prompt_template

                    app_data = db.session.query(App) \
                        .filter(App.id == data.app_id) \
                        .one()

                    account_data = db.session.query(Account) \
                        .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
                        .filter(TenantAccountJoin.role == 'owner') \
                        .filter(TenantAccountJoin.tenant_id == app_data.tenant_id) \
                        .one_or_none()

                    if not account_data:
                        continue

                    if data.user_input_form is None or data.user_input_form == 'null':
                        data.user_input_form = json.dumps(user_input_form_template[account_data.interface_language])
                    else:
                        raw_json_data = json.loads(data.user_input_form)
                        raw_json_data.append(user_input_form_template[account_data.interface_language][0])
                        data.user_input_form = json.dumps(raw_json_data)

                    # click.secho(f"Updated data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')

                db.session.commit()

            except Exception as e:
                click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
                            fg='red')
                continue

            click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')

            pbar.update(len(data_batch))
```
```
@click.command('migrate_default_input_to_dataset_query_variable')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def migrate_default_input_to_dataset_query_variable(batch_size):
    click.secho("Starting...", fg='green')

    total_records = db.session.query(AppModelConfig) \
        .join(App, App.app_model_config_id == AppModelConfig.id) \
        .filter(App.mode == 'completion') \
        .filter(AppModelConfig.dataset_query_variable == None) \
        .count()

    if total_records == 0:
        click.secho("No data to migrate.", fg='green')
        return

    num_batches = (total_records + batch_size - 1) // batch_size

    with tqdm(total=total_records, desc="Migrating Data") as pbar:
        for i in range(num_batches):
            offset = i * batch_size
            limit = min(batch_size, total_records - offset)

            click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')

            data_batch = db.session.query(AppModelConfig) \
                .join(App, App.app_model_config_id == AppModelConfig.id) \
                .filter(App.mode == 'completion') \
                .filter(AppModelConfig.dataset_query_variable == None) \
                .order_by(App.created_at) \
                .offset(offset).limit(limit).all()

            if not data_batch:
                click.secho("No more data to migrate.", fg='green')
                break

            try:
                click.secho(f"Migrating {len(data_batch)} records...", fg='green')
                for data in data_batch:
                    config = AppModelConfig.to_dict(data)

                    tools = config["agent_mode"]["tools"]
                    dataset_exists = "dataset" in str(tools)
                    if not dataset_exists:
                        continue

                    user_input_form = config.get("user_input_form", [])
                    for form in user_input_form:
                        paragraph = form.get('paragraph')
                        if paragraph \
                                and paragraph.get('variable') == 'query':
                            data.dataset_query_variable = 'query'
                            break

                        if paragraph \
                                and paragraph.get('variable') == 'default_input':
                            data.dataset_query_variable = 'default_input'
                            break

                db.session.commit()

            except Exception as e:
                click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
                            fg='red')
                continue

            click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')

            pbar.update(len(data_batch))
```
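Both migration commands above share the same batching skeleton: count the matching rows, page through them with offset/limit, mutate each row, and commit once per batch while tqdm tracks progress. Distilled into a generic helper, the pattern looks roughly like this (a sketch, not code from the repo):

```
from tqdm import tqdm

def migrate_in_batches(session, base_query, apply_fn, batch_size=500):
    """Walk a query in fixed-size batches, committing after each batch."""
    total = base_query.count()
    if total == 0:
        return
    num_batches = (total + batch_size - 1) // batch_size
    with tqdm(total=total, desc="Migrating Data") as pbar:
        for i in range(num_batches):
            batch = base_query.offset(i * batch_size).limit(batch_size).all()
            if not batch:
                break
            for row in batch:
                apply_fn(row)
            session.commit()  # one commit per batch keeps transactions small
            pbar.update(len(batch))
```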
```
@click.command('add-annotation-question-field-value', help='add annotation question value')
def add_annotation_question_field_value():
    click.echo(click.style('Start add annotation question value.', fg='green'))
    message_annotations = db.session.query(MessageAnnotation).all()
    message_annotation_deal_count = 0
    if message_annotations:
        for message_annotation in message_annotations:
            try:
                if message_annotation.message_id and not message_annotation.question:
                    message = db.session.query(Message).filter(
                        Message.id == message_annotation.message_id
                    ).first()
                    message_annotation.question = message.query
                    db.session.add(message_annotation)
                    db.session.commit()
                    message_annotation_deal_count += 1
            except Exception as e:
                click.echo(
                    click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)),
                                fg='red'))
    click.echo(
        click.style(f'Congratulations! add annotation question value successful. Deal count {message_annotation_deal_count}', fg='green'))
```
```
def register_commands(app):
    app.cli.add_command(reset_password)
    app.cli.add_command(reset_email)
    app.cli.add_command(generate_invitation_codes)
    app.cli.add_command(reset_encrypt_key_pair)
    app.cli.add_command(recreate_all_dataset_indexes)
    app.cli.add_command(sync_anthropic_hosted_providers)
    app.cli.add_command(clean_unused_dataset_indexes)
    app.cli.add_command(create_qdrant_indexes)
    app.cli.add_command(update_qdrant_indexes)
    app.cli.add_command(update_app_model_configs)
    app.cli.add_command(normalization_collections)
    app.cli.add_command(migrate_default_input_to_dataset_query_variable)
    app.cli.add_command(add_qdrant_full_text_index)
    app.cli.add_command(add_annotation_question_field_value)
```
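`register_commands` is what exposes all of the above through the flask CLI. A minimal, self-contained illustration of the same mechanism (the `hello` command is made up for demonstration):

```
import click
from flask import Flask

app = Flask(__name__)

@click.command('hello', help='Print a greeting.')
@click.option('--name', default='world', help='Who to greet.')
def hello(name):
    click.echo(click.style('Hello, {}!'.format(name), fg='green'))

app.cli.add_command(hello)
# Invoke with: flask hello --name Dify
```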
@@ -40,17 +40,11 @@ DEFAULTS = {
```
    'HOSTED_OPENAI_QUOTA_LIMIT': 200,
    'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
    'HOSTED_OPENAI_PAID_ENABLED': 'False',
-    'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
-    'HOSTED_OPENAI_PAID_MIN_QUANTITY': 1,
-    'HOSTED_OPENAI_PAID_MAX_QUANTITY': 1,
    'HOSTED_AZURE_OPENAI_ENABLED': 'False',
    'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
    'HOSTED_ANTHROPIC_TRIAL_ENABLED': 'False',
    'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
-    'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
-    'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 1,
-    'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 1,
    'HOSTED_MODERATION_ENABLED': 'False',
    'HOSTED_MODERATION_PROVIDERS': '',
    'CLEAN_DAY_SETTING': 30,
```
@@ -93,7 +87,7 @@ class Config:
|
||||
# ------------------------
|
||||
# General Configurations.
|
||||
# ------------------------
|
||||
self.CURRENT_VERSION = "0.4.9"
|
||||
self.CURRENT_VERSION = "0.5.3"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
@@ -262,10 +256,6 @@ class Config:
|
||||
self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
|
||||
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
|
||||
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
|
||||
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
|
||||
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
|
||||
self.HOSTED_OPENAI_PAID_MIN_QUANTITY = int(get_env('HOSTED_OPENAI_PAID_MIN_QUANTITY'))
|
||||
self.HOSTED_OPENAI_PAID_MAX_QUANTITY = int(get_env('HOSTED_OPENAI_PAID_MAX_QUANTITY'))
|
||||
|
||||
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
|
||||
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
|
||||
@@ -277,10 +267,6 @@ class Config:
|
||||
self.HOSTED_ANTHROPIC_TRIAL_ENABLED = get_bool_env('HOSTED_ANTHROPIC_TRIAL_ENABLED')
|
||||
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
|
||||
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
|
||||
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
|
||||
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
|
||||
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
|
||||
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
|
||||
|
||||
self.HOSTED_MINIMAX_ENABLED = get_bool_env('HOSTED_MINIMAX_ENABLED')
|
||||
self.HOSTED_SPARK_ENABLED = get_bool_env('HOSTED_SPARK_ENABLED')
|
||||
|
||||
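The DEFAULTS trimming above reads easiest next to the lookup helpers that `Config.__init__` calls; this is a minimal sketch of the usual pattern, with the helper bodies assumed rather than taken from this diff:

```python
import os

# Assumed shape of the helpers used throughout Config.__init__:
# environment variables win, otherwise the DEFAULTS table supplies the value.
DEFAULTS = {
    'HOSTED_OPENAI_QUOTA_LIMIT': 200,
    'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
}

def get_env(key):
    return os.environ.get(key, DEFAULTS.get(key))

def get_bool_env(key):
    # 'True'/'true' become True; every other value, including None, is False
    return str(get_env(key)).lower() == 'true'
```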
api/constants/languages.py (new file, 327 lines)
@@ -0,0 +1,327 @@
import json

from models.model import AppModelConfig

languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT']

language_timezone_mapping = {
    'en-US': 'America/New_York',
    'zh-Hans': 'Asia/Shanghai',
    'pt-BR': 'America/Sao_Paulo',
    'es-ES': 'Europe/Madrid',
    'fr-FR': 'Europe/Paris',
    'de-DE': 'Europe/Berlin',
    'ja-JP': 'Asia/Tokyo',
    'ko-KR': 'Asia/Seoul',
    'ru-RU': 'Europe/Moscow',
    'it-IT': 'Europe/Rome',
}

def supported_language(lang):
    if lang in languages:
        return lang

    error = ('{lang} is not a valid language.'
             .format(lang=lang))
    raise ValueError(error)

user_input_form_template = {
    "en-US": [
        {
            "paragraph": {
                "label": "Query",
                "variable": "default_input",
                "required": False,
                "default": ""
            }
        }
    ],
    "zh-Hans": [
        {
            "paragraph": {
                "label": "查询内容",
                "variable": "default_input",
                "required": False,
                "default": ""
            }
        }
    ],
    "pt-BR": [
        {
            "paragraph": {
                "label": "Consulta",
                "variable": "default_input",
                "required": False,
                "default": ""
            }
        }
    ],
    "es-ES": [
        {
            "paragraph": {
                "label": "Consulta",
                "variable": "default_input",
                "required": False,
                "default": ""
            }
        }
    ],
}

demo_model_templates = {
    'en-US': [
        {
            'name': 'Translation Assistant',
            'icon': '',
            'icon_background': '',
            'description': 'A multilingual translator that provides translation capabilities in multiple languages, translating user input into the language they need.',
            'mode': 'completion',
            'model_config': AppModelConfig(
                provider='openai',
                model_id='gpt-3.5-turbo-instruct',
                configs={
                    'prompt_template': "Please translate the following text into {{target_language}}:\n",
                    'prompt_variables': [
                        {
                            "key": "target_language",
                            "name": "Target Language",
                            "description": "The language you want to translate into.",
                            "type": "select",
                            "default": "Chinese",
                            'options': [
                                'Chinese',
                                'English',
                                'Japanese',
                                'French',
                                'Russian',
                                'German',
                                'Spanish',
                                'Korean',
                                'Italian',
                            ]
                        }
                    ],
                    'completion_params': {
                        'max_token': 1000,
                        'temperature': 0,
                        'top_p': 0,
                        'presence_penalty': 0.1,
                        'frequency_penalty': 0.1,
                    }
                },
                opening_statement='',
                suggested_questions=None,
                pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
                model=json.dumps({
                    "provider": "openai",
                    "name": "gpt-3.5-turbo-instruct",
                    "mode": "completion",
                    "completion_params": {
                        "max_tokens": 1000,
                        "temperature": 0,
                        "top_p": 0,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1
                    }
                }),
                user_input_form=json.dumps([
                    {
                        "select": {
                            "label": "Target Language",
                            "variable": "target_language",
                            "description": "The language you want to translate into.",
                            "default": "Chinese",
                            "required": True,
                            'options': [
                                'Chinese',
                                'English',
                                'Japanese',
                                'French',
                                'Russian',
                                'German',
                                'Spanish',
                                'Korean',
                                'Italian',
                            ]
                        }
                    },{
                        "paragraph": {
                            "label": "Query",
                            "variable": "query",
                            "required": True,
                            "default": ""
                        }
                    }
                ])
            )
        },
        {
            'name': 'AI Front-end Interviewer',
            'icon': '',
            'icon_background': '',
            'description': 'A simulated front-end interviewer that tests the skill level of front-end development through questioning.',
            'mode': 'chat',
            'model_config': AppModelConfig(
                provider='openai',
                model_id='gpt-3.5-turbo',
                configs={
                    'introduction': 'Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. ',
                    'prompt_template': "You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? After the user answers this question, please express your understanding and support.\n",
                    'prompt_variables': [],
                    'completion_params': {
                        'max_token': 300,
                        'temperature': 0.8,
                        'top_p': 0.9,
                        'presence_penalty': 0.1,
                        'frequency_penalty': 0.1,
                    }
                },
                opening_statement='Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. ',
                suggested_questions=None,
                pre_prompt="You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? After the user answers this question, please express your understanding and support.\n",
                model=json.dumps({
                    "provider": "openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat",
                    "completion_params": {
                        "max_tokens": 300,
                        "temperature": 0.8,
                        "top_p": 0.9,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1
                    }
                }),
                user_input_form=None
            )
        }
    ],

    'zh-Hans': [
        {
            'name': '翻译助手',
            'icon': '',
            'icon_background': '',
            'description': '一个多语言翻译器,提供多种语言翻译能力,将用户输入的文本翻译成他们需要的语言。',
            'mode': 'completion',
            'model_config': AppModelConfig(
                provider='openai',
                model_id='gpt-3.5-turbo-instruct',
                configs={
                    'prompt_template': "请将以下文本翻译为{{target_language}}:\n",
                    'prompt_variables': [
                        {
                            "key": "target_language",
                            "name": "目标语言",
                            "description": "翻译的目标语言",
                            "type": "select",
                            "default": "中文",
                            "options": [
                                "中文",
                                "英文",
                                "日语",
                                "法语",
                                "俄语",
                                "德语",
                                "西班牙语",
                                "韩语",
                                "意大利语",
                            ]
                        }
                    ],
                    'completion_params': {
                        'max_token': 1000,
                        'temperature': 0,
                        'top_p': 0,
                        'presence_penalty': 0.1,
                        'frequency_penalty': 0.1,
                    }
                },
                opening_statement='',
                suggested_questions=None,
                pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
                model=json.dumps({
                    "provider": "openai",
                    "name": "gpt-3.5-turbo-instruct",
                    "mode": "completion",
                    "completion_params": {
                        "max_tokens": 1000,
                        "temperature": 0,
                        "top_p": 0,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1
                    }
                }),
                user_input_form=json.dumps([
                    {
                        "select": {
                            "label": "目标语言",
                            "variable": "target_language",
                            "description": "翻译的目标语言",
                            "default": "中文",
                            "required": True,
                            'options': [
                                "中文",
                                "英文",
                                "日语",
                                "法语",
                                "俄语",
                                "德语",
                                "西班牙语",
                                "韩语",
                                "意大利语",
                            ]
                        }
                    },{
                        "paragraph": {
                            "label": "文本内容",
                            "variable": "query",
                            "required": True,
                            "default": ""
                        }
                    }
                ])
            )
        },
        {
            'name': 'AI 前端面试官',
            'icon': '',
            'icon_background': '',
            'description': '一个模拟的前端面试官,通过提问的方式对前端开发的技能水平进行检验。',
            'mode': 'chat',
            'model_config': AppModelConfig(
                provider='openai',
                model_id='gpt-3.5-turbo',
                configs={
                    'introduction': '你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。',
                    'prompt_template': "你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n",
                    'prompt_variables': [],
                    'completion_params': {
                        'max_token': 300,
                        'temperature': 0.8,
                        'top_p': 0.9,
                        'presence_penalty': 0.1,
                        'frequency_penalty': 0.1,
                    }
                },
                opening_statement='你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。',
                suggested_questions=None,
                pre_prompt="你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n",
                model=json.dumps({
                    "provider": "openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat",
                    "completion_params": {
                        "max_tokens": 300,
                        "temperature": 0.8,
                        "top_p": 0.9,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1
                    }
                }),
                user_input_form=None
            )
        }
    ],

}
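Several hunks below swap their supported_language import from libs.helper to this module. A short sketch of how such a validator is typically wired into flask_restful argument parsing follows; the argument name here is illustrative, not taken from this diff:

```python
from flask_restful import reqparse

from constants.languages import supported_language

parser = reqparse.RequestParser()
# As a reqparse type, supported_language returns valid codes unchanged and
# raises ValueError for anything else, which flask_restful turns into a 400.
parser.add_argument('interface_language', type=supported_language,
                    required=True, location='json')
```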
@@ -96,258 +96,3 @@ model_templates = {
}

[252 deleted lines omitted: the demo_model_templates dict removed here from api/constants/model_template.py is identical, line for line, to the demo_model_templates block shown above in the new api/constants/languages.py]
@@ -11,12 +11,11 @@ from .app import (advanced_prompt_template, annotation, app, audio, completion,
                  model_config, site, statistic)
# Import auth controllers
from .auth import activate, data_source_oauth, login, oauth
# Import billing controllers
from .billing import billing
# Import datasets controllers
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
# Import explore controllers
from .explore import audio, completion, conversation, installed_app, message, parameter, recommended_app, saved_message
# Import universal chat controllers
from .universal_chat import audio, chat, conversation, message, parameter
# Import workspace controllers
from .workspace import account, members, model_providers, models, tool_providers, workspace
@@ -1,12 +1,12 @@
import os
from functools import wraps

+from constants.languages import supported_language
from controllers.console import api
from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db
from flask import request
from flask_restful import Resource, reqparse
-from libs.helper import supported_language
from models.model import App, InstalledApp, RecommendedApp
from werkzeug.exceptions import NotFound, Unauthorized

@@ -61,9 +61,7 @@ class BaseApiKeyListResource(Resource):
        resource_id = str(resource_id)
        _get_resource(resource_id, current_user.current_tenant_id,
                      self.resource_model)

        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        current_key_count = db.session.query(ApiToken). \
@@ -102,7 +100,7 @@ class BaseApiKeyResource(Resource):
                      self.resource_model)

        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        key = db.session.query(ApiToken). \
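The role-check rewrites repeated through the hunks below lean on an is_admin_or_owner property on the account model. A minimal sketch of what that property presumably encapsulates; the real definition lives in the models package and may differ:

```python
class Account:  # sketch only; the actual model is a SQLAlchemy declarative class
    @property
    def is_admin_or_owner(self):
        # Equivalent to the inline check these diffs replace: the account's
        # role in the tenant-account join table must be admin or owner.
        return self.current_tenant.current_role in ('admin', 'owner')
```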
@@ -21,7 +21,7 @@ class AnnotationReplyActionApi(Resource):
    @cloud_edition_billing_resource_check('annotation')
    def post(self, app_id, action):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        app_id = str(app_id)
@@ -45,7 +45,7 @@ class AppAnnotationSettingDetailApi(Resource):
    @account_initialization_required
    def get(self, app_id):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        app_id = str(app_id)
@@ -59,7 +59,7 @@ class AppAnnotationSettingUpdateApi(Resource):
    @account_initialization_required
    def post(self, app_id, annotation_setting_id):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        app_id = str(app_id)
@@ -80,7 +80,7 @@ class AnnotationReplyActionStatusApi(Resource):
    @cloud_edition_billing_resource_check('annotation')
    def get(self, app_id, job_id, action):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        job_id = str(job_id)
@@ -108,7 +108,7 @@ class AnnotationListApi(Resource):
    @account_initialization_required
    def get(self, app_id):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        page = request.args.get('page', default=1, type=int)
@@ -133,7 +133,7 @@ class AnnotationExportApi(Resource):
    @account_initialization_required
    def get(self, app_id):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        app_id = str(app_id)
@@ -152,7 +152,7 @@ class AnnotationCreateApi(Resource):
    @marshal_with(annotation_fields)
    def post(self, app_id):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        app_id = str(app_id)
@@ -172,7 +172,7 @@ class AnnotationUpdateDeleteApi(Resource):
    @marshal_with(annotation_fields)
    def post(self, app_id, annotation_id):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        app_id = str(app_id)
@@ -189,7 +189,7 @@ class AnnotationUpdateDeleteApi(Resource):
    @account_initialization_required
    def delete(self, app_id, annotation_id):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        app_id = str(app_id)
@@ -205,7 +205,7 @@ class AnnotationBatchImportApi(Resource):
    @cloud_edition_billing_resource_check('annotation')
    def post(self, app_id):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        app_id = str(app_id)
@@ -230,7 +230,7 @@ class AnnotationBatchImportStatusApi(Resource):
    @cloud_edition_billing_resource_check('annotation')
    def get(self, app_id, job_id):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        job_id = str(job_id)
@@ -257,7 +257,7 @@ class AnnotationHitHistoryListApi(Resource):
    @account_initialization_required
    def get(self, app_id, annotation_id):
        # The role of the current user in the table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        page = request.args.get('page', default=1, type=int)
@@ -3,7 +3,8 @@ import json
import logging
from datetime import datetime

-from constants.model_template import demo_model_templates, model_templates
+from constants.languages import demo_model_templates, languages
+from constants.model_template import model_templates
from controllers.console import api
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
@@ -16,10 +17,12 @@ from events.app_event import app_was_created, app_was_deleted
from extensions.ext_database import db
from fields.app_fields import (app_detail_fields, app_detail_fields_with_site, app_pagination_fields,
                               template_list_fields)
from flask import current_app
from flask_login import current_user
from flask_restful import Resource, abort, inputs, marshal_with, reqparse
from libs.login import login_required
from models.model import App, AppModelConfig, Site
from models.tools import ApiToolProvider
from services.app_model_config_service import AppModelConfigService
from werkzeug.exceptions import Forbidden

@@ -42,14 +45,31 @@ class AppListApi(Resource):
        parser = reqparse.RequestParser()
        parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
        parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
+       parser.add_argument('mode', type=str, choices=['chat', 'completion', 'all'], default='all', location='args', required=False)
+       parser.add_argument('name', type=str, location='args', required=False)
        args = parser.parse_args()

+       filters = [
+           App.tenant_id == current_user.current_tenant_id,
+           App.is_universal == False
+       ]
+
+       if args['mode'] == 'completion':
+           filters.append(App.mode == 'completion')
+       elif args['mode'] == 'chat':
+           filters.append(App.mode == 'chat')
+       else:
+           pass
+
+       if 'name' in args and args['name']:
+           filters.append(App.name.ilike(f'%{args["name"]}%'))
+
        app_models = db.paginate(
-           db.select(App).where(App.tenant_id == current_user.current_tenant_id,
-                                App.is_universal == False).order_by(App.created_at.desc()),
+           db.select(App).where(*filters).order_by(App.created_at.desc()),
            page=args['page'],
            per_page=args['limit'],
-           error_out=False)
+           error_out=False
+       )

        return app_models

@@ -62,14 +82,14 @@ class AppListApi(Resource):
        """Create app"""
        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=True, location='json')
-       parser.add_argument('mode', type=str, choices=['completion', 'chat'], location='json')
+       parser.add_argument('mode', type=str, choices=['completion', 'chat', 'assistant'], location='json')
        parser.add_argument('icon', type=str, location='json')
        parser.add_argument('icon_background', type=str, location='json')
        parser.add_argument('model_config', type=dict, location='json')
        args = parser.parse_args()

        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        try:
@@ -88,20 +108,33 @@ class AppListApi(Resource):
            # validate config
            model_config_dict = args['model_config']

-           # get model provider
-           model_manager = ModelManager()
-           model_instance = model_manager.get_default_model_instance(
-               tenant_id=current_user.current_tenant_id,
-               model_type=ModelType.LLM
+           # Get provider configurations
+           provider_manager = ProviderManager()
+           provider_configurations = provider_manager.get_configurations(current_user.current_tenant_id)
+
+           # get available models from provider_configurations
+           available_models = provider_configurations.get_models(
+               model_type=ModelType.LLM,
+               only_active=True
            )

-           if not model_instance:
-               raise ProviderNotInitializeError(
-                   f"No Default System Reasoning Model available. Please configure "
-                   f"in the Settings -> Model Provider.")
-           else:
-               model_config_dict["model"]["provider"] = model_instance.provider
-               model_config_dict["model"]["name"] = model_instance.model
+           # check if model is available
+           available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
+           provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
+           if provider_model not in available_models_names:
+               model_manager = ModelManager()
+               model_instance = model_manager.get_default_model_instance(
+                   tenant_id=current_user.current_tenant_id,
+                   model_type=ModelType.LLM
+               )
+
+               if not model_instance:
+                   raise ProviderNotInitializeError(
+                       f"No Default System Reasoning Model available. Please configure "
+                       f"in the Settings -> Model Provider.")
+               else:
+                   model_config_dict["model"]["provider"] = model_instance.provider
+                   model_config_dict["model"]["name"] = model_instance.model

            model_configuration = AppModelConfigService.validate_configuration(
                tenant_id=current_user.current_tenant_id,
@@ -178,7 +211,7 @@ class AppListApi(Resource):
        app_was_created.send(app)

        return app, 201


class AppTemplateApi(Resource):

@@ -193,7 +226,7 @@ class AppTemplateApi(Resource):

        templates = demo_model_templates.get(interface_language)
        if not templates:
-           templates = demo_model_templates.get('en-US')
+           templates = demo_model_templates.get(languages[0])

        return {'data': templates}

@@ -218,7 +251,7 @@ class AppApi(Resource):
        """Delete app"""
        app_id = str(app_id)

-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        app = _get_app(app_id, current_user.current_tenant_id)
@@ -32,9 +32,9 @@ class ChatMessageAudioApi(Resource):
        file = request.files['file']

        try:
-           response = AudioService.transcript(
+           response = AudioService.transcript_asr(
                tenant_id=app_model.tenant_id,
-               file=file,
+               file=file
            )

            return response
@@ -62,6 +62,48 @@ class ChatMessageAudioApi(Resource):
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()


-api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')

class ChatMessageTextApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, app_id):
        app_id = str(app_id)
        app_model = _get_app(app_id, None)
        try:
            response = AudioService.transcript_tts(
                tenant_id=app_model.tenant_id,
                text=request.form['text'],
                streaming=False
            )

            return {'data': response.data.decode('latin1')}
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
        except NoAudioUploadedServiceError:
            raise NoAudioUploadedError()
        except AudioTooLargeServiceError as e:
            raise AudioTooLargeError(str(e))
        except UnsupportedAudioTypeServiceError:
            raise UnsupportedAudioTypeError()
        except ProviderNotSupportSpeechToTextServiceError:
            raise ProviderNotSupportSpeechToTextError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()


+api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
+api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
@@ -157,7 +157,7 @@ class MessageAnnotationApi(Resource):
    @marshal_with(annotation_fields)
    def post(self, app_id):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        app_id = str(app_id)
@@ -1,4 +1,5 @@
# -*- coding:utf-8 -*-
+from constants.languages import supported_language
from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
@@ -7,7 +8,6 @@ from extensions.ext_database import db
from fields.app_fields import app_site_fields
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
-from libs.helper import supported_language
from libs.login import login_required
from models.model import Site
from werkzeug.exceptions import Forbidden, NotFound
@@ -42,7 +42,7 @@ class AppSite(Resource):
        app_model = _get_app(app_id)

        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        site = db.session.query(Site). \
@@ -88,7 +88,7 @@ class AppSiteAccessTokenReset(Resource):
        app_model = _get_app(app_id)

        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
@@ -2,11 +2,12 @@ import base64
import secrets
from datetime import datetime

+from constants.languages import supported_language
from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from flask_restful import Resource, reqparse
-from libs.helper import email, str_len, supported_language, timezone
+from libs.helper import email, str_len, timezone
from libs.password import hash_password, valid_password
from models.account import AccountStatus, Tenant
from services.account_service import RegisterService
@@ -30,7 +30,7 @@ def get_oauth_providers():
class OAuthDataSource(Resource):
    def get(self, provider: str):
        # The role of the current user in the table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()
        OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
        with current_app.app_context():
@@ -3,6 +3,7 @@ from datetime import datetime
from typing import Optional

import requests
+from constants.languages import languages
from extensions.ext_database import db
from flask import current_app, redirect, request
from flask_restful import Resource
@@ -106,11 +107,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
        )

    # Set interface language
-   preferred_lang = request.accept_languages.best_match(['zh', 'en'])
-   if preferred_lang == 'zh':
-       interface_language = 'zh-Hans'
+   preferred_lang = request.accept_languages.best_match(languages)
+   if preferred_lang and preferred_lang in languages:
+       interface_language = preferred_lang
    else:
-       interface_language = 'en-US'
+       interface_language = languages[0]
    account.interface_language = interface_language
    db.session.commit()
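The rewritten branch delegates matching to Werkzeug's Accept-Language handling; a tiny self-contained illustration with assumed header values:

```python
from werkzeug.datastructures import LanguageAccept

# Simulates a request whose Accept-Language header prefers zh-Hans over en-US.
accept = LanguageAccept([('zh-Hans', 1), ('en-US', 0.8)])
# best_match returns the highest-quality entry present in the candidate list,
# or None when nothing matches, which the code above maps to languages[0].
print(accept.best_match(['en-US', 'zh-Hans', 'ja-JP']))  # -> 'zh-Hans'
```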
@@ -20,7 +20,7 @@ class Subscription(Resource):
        parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
        args = parser.parse_args()

-       BillingService.is_tenant_owner(current_user)
+       BillingService.is_tenant_owner_or_admin(current_user)

        return BillingService.get_subscription(args['plan'],
                                               args['interval'],
@@ -35,8 +35,8 @@ class Invoices(Resource):
    @account_initialization_required
    @only_edition_cloud
    def get(self):
-       BillingService.is_tenant_owner(current_user)
-       return BillingService.get_invoices(current_user.email)
+       BillingService.is_tenant_owner_or_admin(current_user)
+       return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)


api.add_resource(Subscription, '/billing/subscription')
@@ -103,7 +103,7 @@ class DatasetListApi(Resource):
        args = parser.parse_args()

        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        try:
@@ -187,7 +187,7 @@ class DatasetApi(Resource):
        args = parser.parse_args()

        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        dataset = DatasetService.update_dataset(
@@ -205,7 +205,7 @@ class DatasetApi(Resource):
        dataset_id_str = str(dataset_id)

        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        if DatasetService.delete_dataset(dataset_id_str, current_user):
@@ -391,7 +391,7 @@ class DatasetApiKeyApi(Resource):
    @marshal_with(api_key_fields)
    def post(self):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        current_key_count = db.session.query(ApiToken). \
@@ -425,7 +425,7 @@ class DatasetApiDeleteApi(Resource):
        api_key_id = str(api_key_id)

        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        key = db.session.query(ApiToken). \
@@ -204,7 +204,7 @@ class DatasetDocumentListApi(Resource):
            raise NotFound('Dataset not found.')

        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        try:
@@ -256,7 +256,7 @@ class DatasetInitApi(Resource):
    @cloud_edition_billing_resource_check('vector_space')
    def post(self):
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
@@ -599,7 +599,7 @@ class DocumentProcessingApi(DocumentResource):
        document = self.get_document(dataset_id, document_id)

        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        if action == "pause":
@@ -663,7 +663,7 @@ class DocumentMetadataApi(DocumentResource):
        doc_metadata = req_data.get('doc_metadata')

        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        if doc_type is None or doc_metadata is None:
@@ -710,7 +710,7 @@ class DocumentStatusApi(DocumentResource):
        document = self.get_document(dataset_id, document_id)

        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        indexing_cache_key = 'document_{}_indexing'.format(document.id)
@@ -123,7 +123,7 @@ class DatasetDocumentSegmentApi(Resource):
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()

        try:
@@ -219,7 +219,7 @@ class DatasetDocumentSegmentAddApi(Resource):
        if not document:
            raise NotFound('Document not found.')
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()
        # check embedding model setting
        if dataset.indexing_technique == 'high_quality':
@@ -298,7 +298,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
        if not segment:
            raise NotFound('Segment not found.')
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
@@ -342,7 +342,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
        if not segment:
            raise NotFound('Segment not found.')
        # The role of the current user in the ta table must be admin or owner
-       if current_user.current_tenant.current_role not in ['admin', 'owner']:
+       if not current_user.is_admin_or_owner:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
@@ -9,7 +9,7 @@ from flask import current_app, request
from flask_login import current_user
from flask_restful import Resource, marshal_with
from libs.login import login_required
-from services.file_service import FileService
+from services.file_service import ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS, FileService

PREVIEW_WORDS_LIMIT = 3000

@@ -71,11 +71,7 @@ class FileSupportTypeApi(Resource):
    @account_initialization_required
    def get(self):
        etl_type = current_app.config['ETL_TYPE']
-       if etl_type == 'Unstructured':
-           allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
-                                 'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml']
-       else:
-           allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
+       allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
        return {'allowed_extensions': allowed_extensions}

@@ -13,6 +13,16 @@ class NotSetupError(BaseHTTPException):
                  "Please proceed with the initialization and installation process first."
    code = 401


class NotInitValidateError(BaseHTTPException):
    error_code = 'not_init_validated'
    description = "Init validation has not been completed yet. " \
                  "Please proceed with the init validation process first."
    code = 401


class InitValidateFailedError(BaseHTTPException):
    error_code = 'init_validate_failed'
    description = "Init validation failed. Please check the password and try again."
    code = 401


class AccountNotLinkTenantError(BaseHTTPException):
    error_code = 'account_not_link_tenant'
@@ -29,9 +29,10 @@ class ChatAudioApi(InstalledAppResource):
        file = request.files['file']

        try:
-           response = AudioService.transcript(
+           response = AudioService.transcript_asr(
                tenant_id=app_model.tenant_id,
                file=file,
+               end_user=None
            )

            return response
@@ -59,6 +60,48 @@ class ChatAudioApi(InstalledAppResource):
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()


-api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')

class ChatTextApi(InstalledAppResource):
    def post(self, installed_app):
        app_model = installed_app.app
        app_model_config: AppModelConfig = app_model.app_model_config

        if not app_model_config.text_to_speech_dict['enabled']:
            raise AppUnavailableError()

        try:
            response = AudioService.transcript_tts(
                tenant_id=app_model.tenant_id,
                text=request.form['text'],
                streaming=False
            )
            return {'data': response.data.decode('latin1')}
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
        except NoAudioUploadedServiceError:
            raise NoAudioUploadedError()
        except AudioTooLargeServiceError as e:
            raise AudioTooLargeError(str(e))
        except UnsupportedAudioTypeServiceError:
            raise UnsupportedAudioTypeError()
        except ProviderNotSupportSpeechToTextServiceError:
            raise ProviderNotSupportSpeechToTextError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()


+api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
+api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
@@ -33,8 +33,9 @@ class InstalledAppsListApi(Resource):
                'app_owner_tenant_id': installed_app.app_owner_tenant_id,
                'is_pinned': installed_app.is_pinned,
                'last_used_at': installed_app.last_used_at,
-               "editable": current_user.role in ["owner", "admin"],
-               "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id
+               'editable': current_user.role in ["owner", "admin"],
+               'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id,
+               'is_agent': installed_app.is_agent
            }
            for installed_app in installed_apps
        ]
@@ -17,9 +17,9 @@ from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from flask import Response, stream_with_context
from flask_login import current_user
-from flask_restful import marshal_with, reqparse
+from flask_restful import fields, marshal_with, reqparse
from flask_restful.inputs import int_range
-from libs.helper import uuid_value
+from libs.helper import TimestampField, uuid_value
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@@ -29,7 +29,6 @@ from werkzeug.exceptions import InternalServerError, NotFound


class MessageListApi(InstalledAppResource):

    @marshal_with(message_infinite_scroll_pagination_fields)
    def get(self, installed_app):
        app_model = installed_app.app
@@ -51,7 +50,6 @@ class MessageListApi(InstalledAppResource):
        except services.errors.message.FirstMessageNotExistsError:
            raise NotFound("First Message Not Exists.")


class MessageFeedbackApi(InstalledAppResource):
    def post(self, installed_app, message_id):
        app_model = installed_app.app
@@ -1,9 +1,13 @@
# -*- coding:utf-8 -*-
+import json
+
from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource
+from extensions.ext_database import db
+from flask import current_app
from flask_restful import fields, marshal_with
-from models.model import InstalledApp
+from models.model import AppModelConfig, InstalledApp
+from models.tools import ApiToolProvider


class AppParameterApi(InstalledAppResource):
@@ -27,6 +31,7 @@ class AppParameterApi(InstalledAppResource):
        'suggested_questions': fields.Raw,
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
+       'text_to_speech': fields.Raw,
        'retriever_resource': fields.Raw,
        'annotation_reply': fields.Raw,
        'more_like_this': fields.Raw,
@@ -47,6 +52,7 @@ class AppParameterApi(InstalledAppResource):
            'suggested_questions': app_model_config.suggested_questions_list,
            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
            'speech_to_text': app_model_config.speech_to_text_dict,
+           'text_to_speech': app_model_config.text_to_speech_dict,
            'retriever_resource': app_model_config.retriever_resource_dict,
            'annotation_reply': app_model_config.annotation_reply_dict,
            'more_like_this': app_model_config.more_like_this_dict,
@@ -58,5 +64,42 @@ class AppParameterApi(InstalledAppResource):
        }
    }

class ExploreAppMetaApi(InstalledAppResource):
    def get(self, installed_app: InstalledApp):
        """Get app meta"""
        app_model_config: AppModelConfig = installed_app.app.app_model_config

        agent_config = app_model_config.agent_mode_dict or {}
        meta = {
            'tool_icons': {}
        }

        # get all tools
        tools = agent_config.get('tools', [])
        url_prefix = (current_app.config.get("CONSOLE_API_URL")
                      + f"/console/api/workspaces/current/tool-provider/builtin/")
        for tool in tools:
            keys = list(tool.keys())
            if len(keys) >= 4:
                # current tool standard
                provider_type = tool.get('provider_type')
                provider_id = tool.get('provider_id')
                tool_name = tool.get('tool_name')
                if provider_type == 'builtin':
                    meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon'
                elif provider_type == 'api':
                    try:
                        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
                            ApiToolProvider.id == provider_id
                        )
                        meta['tool_icons'][tool_name] = json.loads(provider.icon)
                    except:
                        meta['tool_icons'][tool_name] = {
                            "background": "#252525",
                            "content": "\ud83d\ude01"
                        }

        return meta

api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters', endpoint='installed_app_parameters')
api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta')
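One caveat in the ExploreAppMetaApi block above: the api branch assigns the bare Query object to provider without fetching a row, so provider.icon raises and the bare except always falls back to the default icon. A hedged sketch of the presumably intended lookup, not the committed code:

```python
provider = db.session.query(ApiToolProvider).filter(
    ApiToolProvider.id == provider_id
).first()  # .first() actually executes the query and returns one row or None
if provider is not None:
    meta['tool_icons'][tool_name] = json.loads(provider.icon)
```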
@@ -1,4 +1,5 @@
# -*- coding:utf-8 -*-
+from constants.languages import languages
from controllers.console import api
from controllers.console.app.error import AppNotFoundError
from controllers.console.wraps import account_initialization_required
@@ -29,7 +30,8 @@ recommended_app_fields = {
    'is_listed': fields.Boolean,
    'install_count': fields.Integer,
    'installed': fields.Boolean,
-   'editable': fields.Boolean
+   'editable': fields.Boolean,
+   'is_agent': fields.Boolean
}

recommended_app_list_fields = {
@@ -43,7 +45,7 @@ class RecommendedAppListApi(Resource):
    @account_initialization_required
    @marshal_with(recommended_app_list_fields)
    def get(self):
-       language_prefix = current_user.interface_language if current_user.interface_language else 'en-US'
+       language_prefix = current_user.interface_language if current_user.interface_language else languages[0]

        recommended_apps = db.session.query(RecommendedApp).filter(
            RecommendedApp.is_listed == True,
@@ -82,6 +84,7 @@ class RecommendedAppListApi(Resource):
                'install_count': recommended_app.install_count,
                'installed': installed,
                'editable': current_user.role in ['owner', 'admin'],
+               "is_agent": app.is_agent
            }
            recommended_apps_result.append(recommended_app_result)

@@ -3,10 +3,12 @@ from flask_restful import Resource
 from services.feature_service import FeatureService
 
 from . import api
+from .wraps import cloud_utm_record
 
 
 class FeatureApi(Resource):
 
+    @cloud_utm_record
     def get(self):
         return FeatureService.get_features(current_user.current_tenant_id).dict()
api/controllers/console/init_validate.py (new file, 48 lines)
@@ -0,0 +1,48 @@
import os

from flask import current_app, session
from flask_restful import Resource, reqparse
from libs.helper import str_len
from models.model import DifySetup
from services.account_service import TenantService

from . import api
from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted


class InitValidateAPI(Resource):

    def get(self):
        init_status = get_init_validate_status()
        if init_status:
            return {'status': 'finished'}
        return {'status': 'not_started'}

    @only_edition_self_hosted
    def post(self):
        # reject if a tenant has already been created
        tenant_count = TenantService.get_tenant_count()
        if tenant_count > 0:
            raise AlreadySetupError()

        parser = reqparse.RequestParser()
        parser.add_argument('password', type=str_len(30),
                            required=True, location='json')
        input_password = parser.parse_args()['password']

        if input_password != os.environ.get('INIT_PASSWORD'):
            session['is_init_validated'] = False
            raise InitValidateFailedError()

        session['is_init_validated'] = True
        return {'result': 'success'}, 201


def get_init_validate_status():
    if current_app.config['EDITION'] == 'SELF_HOSTED':
        if os.environ.get('INIT_PASSWORD'):
            return session.get('is_init_validated') or DifySetup.query.first()

    return True


api.add_resource(InitValidateAPI, '/init')
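The /init endpoint gates self-hosted setup behind an INIT_PASSWORD environment variable and remembers a successful check in the session cookie. A minimal sketch of the client flow, assuming the console API is mounted at /console/api (the host below is a placeholder):

import requests

BASE = 'http://localhost:5001/console/api'  # hypothetical host
s = requests.Session()  # the session cookie carries 'is_init_validated'

if s.get(f'{BASE}/init').json()['status'] == 'not_started':
    r = s.post(f'{BASE}/init', json={'password': 'my-init-password'})  # value of INIT_PASSWORD
    assert r.status_code == 201 and r.json()['result'] == 'success'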
@@ -10,7 +10,8 @@ from models.model import DifySetup
 from services.account_service import AccountService, RegisterService, TenantService
 
 from . import api
-from .error import AlreadySetupError, NotSetupError
+from .error import AlreadySetupError, NotInitValidateError, NotSetupError
+from .init_validate import get_init_validate_status
 from .wraps import only_edition_self_hosted
 
 
@@ -24,7 +25,7 @@ class SetupApi(Resource):
                 'step': 'finished',
                 'setup_at': setup_status.setup_at.isoformat()
             }
-            return {'step': 'not_start'}
+            return {'step': 'not_started'}
 
         return {'step': 'finished'}
 
     @only_edition_self_hosted
@@ -37,6 +38,9 @@ class SetupApi(Resource):
         tenant_count = TenantService.get_tenant_count()
         if tenant_count > 0:
             raise AlreadySetupError()
 
+        if not get_init_validate_status():
+            raise NotInitValidateError()
+
         parser = reqparse.RequestParser()
         parser.add_argument('email', type=email,
@@ -71,7 +75,10 @@ def setup_required(view):
     @wraps(view)
     def decorated(*args, **kwargs):
         # check setup
-        if not get_setup_status():
+        if not get_init_validate_status():
+            raise NotInitValidateError()
+
+        elif not get_setup_status():
             raise NotSetupError()
 
         return view(*args, **kwargs)
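Note that GET /setup now reports 'not_started' instead of 'not_start'; clients that switch on the old string need updating. A defensive sketch of the status check (the host is a placeholder; accepting both spellings is a rollout suggestion, not code from this diff):

import requests

BASE = 'http://localhost:5001/console/api'  # hypothetical host

step = requests.get(f'{BASE}/setup').json()['step']
if step in ('not_start', 'not_started'):  # tolerate both spellings across versions
    print('admin setup still required')
elif step == 'finished':
    print('instance already set up')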
@@ -1,64 +0,0 @@
# -*- coding:utf-8 -*-
import logging

import services
from controllers.console import api
from controllers.console.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError,
                                           NoAudioUploadedError, ProviderModelCurrentlyNotSupportError,
                                           ProviderNotInitializeError, ProviderNotSupportSpeechToTextError,
                                           ProviderQuotaExceededError, UnsupportedAudioTypeError)
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from flask import request
from models.model import AppModelConfig
from services.audio_service import AudioService
from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError,
                                   ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError)
from werkzeug.exceptions import InternalServerError


class UniversalChatAudioApi(UniversalChatResource):
    def post(self, universal_app):
        app_model = universal_app
        app_model_config: AppModelConfig = app_model.app_model_config

        if not app_model_config.speech_to_text_dict['enabled']:
            raise AppUnavailableError()

        file = request.files['file']

        try:
            response = AudioService.transcript(
                tenant_id=app_model.tenant_id,
                file=file,
            )

            return response
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
        except NoAudioUploadedServiceError:
            raise NoAudioUploadedError()
        except AudioTooLargeServiceError as e:
            raise AudioTooLargeError(str(e))
        except UnsupportedAudioTypeServiceError:
            raise UnsupportedAudioTypeError()
        except ProviderNotSupportSpeechToTextServiceError:
            raise ProviderNotSupportSpeechToTextError()
        except ProviderTokenNotInitError:
            raise ProviderNotInitializeError()
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()


api.add_resource(UniversalChatAudioApi, '/universal-chat/audio-to-text')
@@ -1,120 +0,0 @@
import json
import logging
from typing import Generator, Union

import services
from controllers.console import api
from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError,
                                           ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError,
                                           ProviderQuotaExceededError)
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import reqparse
from libs.helper import uuid_value
from services.completion_service import CompletionService
from werkzeug.exceptions import InternalServerError, NotFound


class UniversalChatApi(UniversalChatResource):
    def post(self, universal_app):
        app_model = universal_app

        parser = reqparse.RequestParser()
        parser.add_argument('query', type=str, required=True, location='json')
        parser.add_argument('files', type=list, required=False, location='json')
        parser.add_argument('conversation_id', type=uuid_value, location='json')
        parser.add_argument('provider', type=str, required=True, location='json')
        parser.add_argument('model', type=str, required=True, location='json')
        parser.add_argument('tools', type=list, required=True, location='json')
        parser.add_argument('retriever_from', type=str, required=False, default='universal_app', location='json')
        args = parser.parse_args()

        app_model_config = app_model.app_model_config

        # update app model config
        args['model_config'] = app_model_config.to_dict()
        args['model_config']['model']['name'] = args['model']
        args['model_config']['model']['provider'] = args['provider']
        args['model_config']['agent_mode']['tools'] = args['tools']

        if not args['model_config']['agent_mode']['tools']:
            args['model_config']['agent_mode']['tools'] = [
                {
                    "current_datetime": {
                        "enabled": True
                    }
                }
            ]
        else:
            args['model_config']['agent_mode']['tools'].append({
                "current_datetime": {
                    "enabled": True
                }
            })

        args['inputs'] = {}

        del args['model']
        del args['tools']

        args['auto_generate_name'] = False

        try:
            response = CompletionService.completion(
                app_model=app_model,
                user=current_user,
                args=args,
                invoke_from=InvokeFrom.EXPLORE,
                streaming=True,
                is_model_config_override=True,
            )

            return compact_response(response)
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        except services.errors.conversation.ConversationCompletedError:
            raise ConversationCompletedError()
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
        except ProviderTokenNotInitError:
            raise ProviderNotInitializeError()
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()


class UniversalChatStopApi(UniversalChatResource):
    def post(self, universal_app, task_id):
        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

        return {'result': 'success'}, 200


def compact_response(response: Union[dict, Generator]) -> Response:
    if isinstance(response, dict):
        return Response(response=json.dumps(response), status=200, mimetype='application/json')
    else:
        def generate() -> Generator:
            for chunk in response:
                yield chunk

        return Response(stream_with_context(generate()), status=200,
                        mimetype='text/event-stream')


api.add_resource(UniversalChatApi, '/universal-chat/messages')
api.add_resource(UniversalChatStopApi, '/universal-chat/messages/<string:task_id>/stop')
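Before its removal, this endpoint always force-enabled a built-in current_datetime tool on top of whatever tool list the caller sent. That merge reduces to a small helper; a standalone sketch of the same idea (with a duplicate check the original skipped):

def ensure_current_datetime(tools):
    # guarantee the agent always has the datetime tool available;
    # unlike the removed endpoint, this version skips duplicates
    tools = list(tools or [])
    if not any('current_datetime' in tool for tool in tools):
        tools.append({'current_datetime': {'enabled': True}})
    return tools

assert ensure_current_datetime(None) == [{'current_datetime': {'enabled': True}}]
assert ensure_current_datetime([{'web_search': {'enabled': True}}])[-1] == {'current_datetime': {'enabled': True}}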
@@ -1,110 +0,0 @@
# -*- coding:utf-8 -*-
from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource
from fields.conversation_fields import (conversation_with_model_config_fields,
                                        conversation_with_model_config_infinite_scroll_pagination_fields)
from flask_login import current_user
from flask_restful import fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.web_conversation_service import WebConversationService
from werkzeug.exceptions import NotFound


class UniversalChatConversationListApi(UniversalChatResource):

    @marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields)
    def get(self, universal_app):
        app_model = universal_app

        parser = reqparse.RequestParser()
        parser.add_argument('last_id', type=uuid_value, location='args')
        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
        parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
        args = parser.parse_args()

        pinned = None
        if 'pinned' in args and args['pinned'] is not None:
            pinned = True if args['pinned'] == 'true' else False

        try:
            return WebConversationService.pagination_by_last_id(
                app_model=app_model,
                user=current_user,
                last_id=args['last_id'],
                limit=args['limit'],
                pinned=pinned
            )
        except LastConversationNotExistsError:
            raise NotFound("Last Conversation Not Exists.")


class UniversalChatConversationApi(UniversalChatResource):
    def delete(self, universal_app, c_id):
        app_model = universal_app
        conversation_id = str(c_id)

        try:
            ConversationService.delete(app_model, conversation_id, current_user)
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")

        WebConversationService.unpin(app_model, conversation_id, current_user)

        return {"result": "success"}, 204


class UniversalChatConversationRenameApi(UniversalChatResource):

    @marshal_with(conversation_with_model_config_fields)
    def post(self, universal_app, c_id):
        app_model = universal_app
        conversation_id = str(c_id)

        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=False, location='json')
        parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
        args = parser.parse_args()

        try:
            return ConversationService.rename(
                app_model,
                conversation_id,
                current_user,
                args['name'],
                args['auto_generate']
            )
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")


class UniversalChatConversationPinApi(UniversalChatResource):

    def patch(self, universal_app, c_id):
        app_model = universal_app
        conversation_id = str(c_id)

        try:
            WebConversationService.pin(app_model, conversation_id, current_user)
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")

        return {"result": "success"}


class UniversalChatConversationUnPinApi(UniversalChatResource):
    def patch(self, universal_app, c_id):
        app_model = universal_app
        conversation_id = str(c_id)
        WebConversationService.unpin(app_model, conversation_id, current_user)

        return {"result": "success"}


api.add_resource(UniversalChatConversationRenameApi, '/universal-chat/conversations/<uuid:c_id>/name')
api.add_resource(UniversalChatConversationListApi, '/universal-chat/conversations')
api.add_resource(UniversalChatConversationApi, '/universal-chat/conversations/<uuid:c_id>')
api.add_resource(UniversalChatConversationPinApi, '/universal-chat/conversations/<uuid:c_id>/pin')
api.add_resource(UniversalChatConversationUnPinApi, '/universal-chat/conversations/<uuid:c_id>/unpin')
@@ -1,145 +0,0 @@
# -*- coding:utf-8 -*-
import logging

import services
from controllers.console import api
from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError,
                                           ProviderNotInitializeError, ProviderQuotaExceededError)
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from flask_login import current_user
from flask_restful import fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from libs.helper import TimestampField, uuid_value
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
from werkzeug.exceptions import InternalServerError, NotFound


class UniversalChatMessageListApi(UniversalChatResource):
    feedback_fields = {
        'rating': fields.String
    }

    agent_thought_fields = {
        'id': fields.String,
        'chain_id': fields.String,
        'message_id': fields.String,
        'position': fields.Integer,
        'thought': fields.String,
        'tool': fields.String,
        'tool_input': fields.String,
        'created_at': TimestampField
    }

    retriever_resource_fields = {
        'id': fields.String,
        'message_id': fields.String,
        'position': fields.Integer,
        'dataset_id': fields.String,
        'dataset_name': fields.String,
        'document_id': fields.String,
        'document_name': fields.String,
        'data_source_type': fields.String,
        'segment_id': fields.String,
        'score': fields.Float,
        'hit_count': fields.Integer,
        'word_count': fields.Integer,
        'segment_position': fields.Integer,
        'index_node_hash': fields.String,
        'content': fields.String,
        'created_at': TimestampField
    }

    message_fields = {
        'id': fields.String,
        'conversation_id': fields.String,
        'inputs': fields.Raw,
        'query': fields.String,
        'answer': fields.String,
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
        'created_at': TimestampField,
        'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
    }

    message_infinite_scroll_pagination_fields = {
        'limit': fields.Integer,
        'has_more': fields.Boolean,
        'data': fields.List(fields.Nested(message_fields))
    }

    @marshal_with(message_infinite_scroll_pagination_fields)
    def get(self, universal_app):
        app_model = universal_app

        parser = reqparse.RequestParser()
        parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
        parser.add_argument('first_id', type=uuid_value, location='args')
        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
        args = parser.parse_args()

        try:
            return MessageService.pagination_by_first_id(app_model, current_user,
                                                         args['conversation_id'], args['first_id'], args['limit'])
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        except services.errors.message.FirstMessageNotExistsError:
            raise NotFound("First Message Not Exists.")


class UniversalChatMessageFeedbackApi(UniversalChatResource):
    def post(self, universal_app, message_id):
        app_model = universal_app
        message_id = str(message_id)

        parser = reqparse.RequestParser()
        parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
        args = parser.parse_args()

        try:
            MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
        except services.errors.message.MessageNotExistsError:
            raise NotFound("Message Not Exists.")

        return {'result': 'success'}


class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
    def get(self, universal_app, message_id):
        app_model = universal_app
        message_id = str(message_id)

        try:
            questions = MessageService.get_suggested_questions_after_answer(
                app_model=app_model,
                user=current_user,
                message_id=message_id
            )
        except MessageNotExistsError:
            raise NotFound("Message not found")
        except ConversationNotExistsError:
            raise NotFound("Conversation not found")
        except SuggestedQuestionsAfterAnswerDisabledError:
            raise AppSuggestedQuestionsAfterAnswerDisabledError()
        except ProviderTokenNotInitError:
            raise ProviderNotInitializeError()
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except Exception:
            logging.exception("internal server error.")
            raise InternalServerError()

        return {'data': questions}


api.add_resource(UniversalChatMessageListApi, '/universal-chat/messages')
api.add_resource(UniversalChatMessageFeedbackApi, '/universal-chat/messages/<uuid:message_id>/feedbacks')
api.add_resource(UniversalChatMessageSuggestedQuestionApi, '/universal-chat/messages/<uuid:message_id>/suggested-questions')
@@ -1,38 +0,0 @@
# -*- coding:utf-8 -*-
import json

from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource
from flask_restful import fields, marshal_with
from models.model import App


class UniversalChatParameterApi(UniversalChatResource):
    """Resource for app variables."""
    parameters_fields = {
        'opening_statement': fields.String,
        'suggested_questions': fields.Raw,
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
        'retriever_resource': fields.Raw,
        'annotation_reply': fields.Raw
    }

    @marshal_with(parameters_fields)
    def get(self, universal_app: App):
        """Retrieve app parameters."""
        app_model = universal_app
        app_model_config = app_model.app_model_config
        app_model_config.retriever_resource = json.dumps({'enabled': True})

        return {
            'opening_statement': app_model_config.opening_statement,
            'suggested_questions': app_model_config.suggested_questions_list,
            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
            'speech_to_text': app_model_config.speech_to_text_dict,
            'retriever_resource': app_model_config.retriever_resource_dict,
            'annotation_reply': app_model_config.annotation_reply_dict,
        }


api.add_resource(UniversalChatParameterApi, '/universal-chat/parameters')
@@ -1,86 +0,0 @@
import json
from functools import wraps

from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from flask_login import current_user
from flask_restful import Resource
from libs.login import login_required
from models.model import App, AppModelConfig


def universal_chat_app_required(view=None):
    def decorator(view):
        @wraps(view)
        def decorated(*args, **kwargs):
            # get universal chat app
            universal_app = db.session.query(App).filter(
                App.tenant_id == current_user.current_tenant_id,
                App.is_universal == True
            ).first()

            if universal_app is None:
                # create universal app if not exists
                universal_app = App(
                    tenant_id=current_user.current_tenant_id,
                    name='Universal Chat',
                    mode='chat',
                    is_universal=True,
                    icon='',
                    icon_background='',
                    api_rpm=0,
                    api_rph=0,
                    enable_site=False,
                    enable_api=False,
                    status='normal'
                )

                db.session.add(universal_app)
                db.session.flush()

                app_model_config = AppModelConfig(
                    provider="",
                    model_id="",
                    configs={},
                    opening_statement='',
                    suggested_questions=json.dumps([]),
                    suggested_questions_after_answer=json.dumps({'enabled': True}),
                    speech_to_text=json.dumps({'enabled': True}),
                    retriever_resource=json.dumps({'enabled': True}),
                    more_like_this=None,
                    sensitive_word_avoidance=None,
                    model=json.dumps({
                        "provider": "openai",
                        "name": "gpt-3.5-turbo-16k",
                        "completion_params": {
                            "max_tokens": 800,
                            "temperature": 0.8,
                            "top_p": 1,
                            "presence_penalty": 0,
                            "frequency_penalty": 0
                        }
                    }),
                    user_input_form=json.dumps([]),
                    pre_prompt='',
                    agent_mode=json.dumps({"enabled": True, "strategy": "function_call", "tools": []}),
                )

                app_model_config.app_id = universal_app.id
                db.session.add(app_model_config)
                db.session.flush()

                universal_app.app_model_config_id = app_model_config.id
                db.session.commit()

            return view(universal_app, *args, **kwargs)
        return decorated

    if view:
        return decorator(view)
    return decorator


class UniversalChatResource(Resource):
    # must be reversed if there are multiple decorators
    method_decorators = [universal_chat_app_required, account_initialization_required, login_required, setup_required]
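The trailing comment refers to Flask-RESTful semantics: method_decorators is applied by wrapping the handler with each list entry in turn, so the last element ends up outermost and runs first, which is why the list reads bottom-up relative to a normal @-stack. A small pure-Python sketch of that ordering, assuming the wrap-in-list-order behavior just described:

def apply_method_decorators(handler, decorators):
    # mimics flask_restful: each entry wraps the previous result,
    # so the LAST list element becomes the outermost wrapper
    for deco in decorators:
        handler = deco(handler)
    return handler

def tag(label):
    def deco(fn):
        def wrapper(*args, **kwargs):
            print('enter', label)
            return fn(*args, **kwargs)
        return wrapper
    return deco

handler = apply_method_decorators(lambda: 'ok', [tag('inner'), tag('outer')])
handler()  # prints: enter outer, enter inner -> same as stacking @tag('outer') above @tag('inner')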
@@ -2,6 +2,7 @@
 from datetime import datetime
 
 import pytz
+from constants.languages import supported_language
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.workspace.error import (AccountAlreadyInitedError, CurrentPasswordIncorrectError,
@@ -11,7 +12,7 @@ from extensions.ext_database import db
 from flask import current_app, request
 from flask_login import current_user
 from flask_restful import Resource, fields, marshal_with, reqparse
-from libs.helper import TimestampField, supported_language, timezone
+from libs.helper import TimestampField, timezone
 from libs.login import login_required
 from models.account import AccountIntegrate, InvitationCode
 from services.account_service import AccountService
@@ -6,10 +6,10 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
 from extensions.ext_database import db
 from flask import current_app
 from flask_login import current_user
-from flask_restful import Resource, abort, fields, marshal, marshal_with, reqparse
+from flask_restful import Resource, abort, fields, marshal_with, reqparse
 from libs.helper import TimestampField
 from libs.login import login_required
-from models.account import Account, TenantAccountJoin
+from models.account import Account
 from services.account_service import RegisterService, TenantService
 
 account_fields = {
@@ -51,10 +51,12 @@ class MemberInviteEmailApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('emails', type=str, required=True, location='json', action='append')
         parser.add_argument('role', type=str, required=True, default='admin', location='json')
+        parser.add_argument('language', type=str, required=False, location='json')
         args = parser.parse_args()
 
         invitee_emails = args['emails']
         invitee_role = args['role']
+        interface_language = args['language']
         if invitee_role not in ['admin', 'normal']:
             return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
 
@@ -63,19 +65,12 @@ class MemberInviteEmailApi(Resource):
         console_web_url = current_app.config.get("CONSOLE_WEB_URL")
         for invitee_email in invitee_emails:
             try:
-                token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
-                                                          inviter=inviter)
-                account = db.session.query(Account, TenantAccountJoin.role).join(
-                    TenantAccountJoin, Account.id == TenantAccountJoin.account_id
-                ).filter(Account.email == invitee_email).first()
-                account, role = account
+                token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter)
                 invitation_results.append({
                     'status': 'success',
                     'email': invitee_email,
                     'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
                 })
-                account = marshal(account, account_fields)
-                account['role'] = role
             except Exception as e:
                 invitation_results.append({
                     'status': 'failed',
@@ -98,7 +98,7 @@ class ModelProviderApi(Resource):
     @login_required
     @account_initialization_required
     def post(self, provider: str):
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+        if not current_user.is_admin_or_owner:
             raise Forbidden()
 
         parser = reqparse.RequestParser()
@@ -122,7 +122,7 @@ class ModelProviderApi(Resource):
     @login_required
     @account_initialization_required
     def delete(self, provider: str):
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+        if not current_user.is_admin_or_owner:
             raise Forbidden()
 
         model_provider_service = ModelProviderService()
@@ -159,7 +159,7 @@ class PreferredProviderTypeUpdateApi(Resource):
     @login_required
     @account_initialization_required
     def post(self, provider: str):
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+        if not current_user.is_admin_or_owner:
             raise Forbidden()
 
         tenant_id = current_user.current_tenant_id
@@ -186,10 +186,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
     def get(self, provider: str):
         if provider != 'anthropic':
             raise ValueError(f'provider name {provider} is invalid')
+        BillingService.is_tenant_owner_or_admin(current_user)
         data = BillingService.get_model_provider_payment_link(provider_name=provider,
                                                               tenant_id=current_user.current_tenant_id,
-                                                              account_id=current_user.id)
+                                                              account_id=current_user.id,
+                                                              prefilled_email=current_user.email)
         return data
@@ -1,136 +1,291 @@
+import io
 import json
 
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.tool.provider.errors import ToolValidateFailedError
-from core.tool.provider.tool_provider_service import ToolProviderService
-from extensions.ext_database import db
+from flask import send_file
 from flask_login import current_user
-from flask_restful import Resource, abort, reqparse
+from flask_restful import Resource, reqparse
 from libs.login import login_required
-from models.tool import ToolProvider, ToolProviderName
+from services.tools_manage_service import ToolManageService
 from werkzeug.exceptions import Forbidden
 
 
 class ToolProviderListApi(Resource):
 
     @setup_required
     @login_required
     @account_initialization_required
     def get(self):
+        user_id = current_user.id
         tenant_id = current_user.current_tenant_id
 
-        tool_credential_dict = {}
-        for tool_name in ToolProviderName:
-            tool_credential_dict[tool_name.value] = {
-                'tool_name': tool_name.value,
-                'is_enabled': False,
-                'credentials': None
-            }
-
-        tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all()
-
-        for p in tool_providers:
-            if p.is_enabled:
-                tool_credential_dict[p.tool_name] = {
-                    'tool_name': p.tool_name,
-                    'is_enabled': p.is_enabled,
-                    'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True)
-                }
-
-        return list(tool_credential_dict.values())
-
-
-class ToolProviderCredentialsApi(Resource):
-
+        return ToolManageService.list_tool_providers(user_id, tenant_id)
+
+
+class ToolBuiltinProviderListToolsApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider):
+        user_id = current_user.id
+        tenant_id = current_user.current_tenant_id
+
+        return ToolManageService.list_builtin_tool_provider_tools(
+            user_id,
+            tenant_id,
+            provider,
+        )
+
+
+class ToolBuiltinProviderDeleteApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     def post(self, provider):
-        if provider not in [p.value for p in ToolProviderName]:
-            abort(404)
-
-        # The role of the current user in the ta table must be admin or owner
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            raise Forbidden(f'User {current_user.id} is not authorized to update provider token, '
-                            f'current_role is {current_user.current_tenant.current_role}')
-
-        parser = reqparse.RequestParser()
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        tool_provider_service = ToolProviderService(tenant_id, provider)
-
-        try:
-            tool_provider_service.credentials_validate(args['credentials'])
-        except ToolValidateFailedError as ex:
-            raise ValueError(str(ex))
-
-        encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials']))
-
-        tenant = current_user.current_tenant
-
-        tool_provider_model = db.session.query(ToolProvider).filter(
-            ToolProvider.tenant_id == tenant.id,
-            ToolProvider.tool_name == provider,
-        ).first()
-
-        # Only allow updating token for CUSTOM provider type
-        if tool_provider_model:
-            tool_provider_model.encrypted_credentials = encrypted_credentials
-            tool_provider_model.is_enabled = True
-        else:
-            tool_provider_model = ToolProvider(
-                tenant_id=tenant.id,
-                tool_name=provider,
-                encrypted_credentials=encrypted_credentials,
-                is_enabled=True
-            )
-            db.session.add(tool_provider_model)
-
-        db.session.commit()
-
-        return {'result': 'success'}, 201
-
-
-class ToolProviderCredentialsValidateApi(Resource):
-
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+
+        user_id = current_user.id
+        tenant_id = current_user.current_tenant_id
+
+        return ToolManageService.delete_builtin_tool_provider(
+            user_id,
+            tenant_id,
+            provider,
+        )
+
+
+class ToolBuiltinProviderUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     def post(self, provider):
-        if provider not in [p.value for p in ToolProviderName]:
-            abort(404)
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+
+        user_id = current_user.id
+        tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
 
         args = parser.parse_args()
 
-        result = True
-        error = None
-
-        tool_provider_service = ToolProviderService(tenant_id, provider)
-
-        try:
-            tool_provider_service.credentials_validate(args['credentials'])
-        except ToolValidateFailedError as ex:
-            result = False
-            error = str(ex)
-
-        response = {'result': 'success' if result else 'error'}
-
-        if not result:
-            response['error'] = error
-
-        return response
+        return ToolManageService.update_builtin_tool_provider(
+            user_id,
+            tenant_id,
+            provider,
+            args['credentials'],
+        )
+
+
+class ToolBuiltinProviderIconApi(Resource):
+    @setup_required
+    def get(self, provider):
+        icon_bytes, mimetype = ToolManageService.get_builtin_tool_provider_icon(provider)  # renamed from the misspelled 'minetype'
+        return send_file(io.BytesIO(icon_bytes), mimetype=mimetype)
+
+
+class ToolApiProviderAddApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+
+        user_id = current_user.id
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json')
+
+        args = parser.parse_args()
+
+        return ToolManageService.create_api_tool_provider(
+            user_id,
+            tenant_id,
+            args['provider'],
+            args['icon'],
+            args['credentials'],
+            args['schema_type'],
+            args['schema'],
+            args.get('privacy_policy', ''),
+        )
+
+
+class ToolApiProviderGetRemoteSchemaApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('url', type=str, required=True, nullable=False, location='args')
+
+        args = parser.parse_args()
+
+        return ToolManageService.get_api_tool_provider_remote_schema(
+            current_user.id,
+            current_user.current_tenant_id,
+            args['url'],
+        )
+
+
+class ToolApiProviderListToolsApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        user_id = current_user.id
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
+
+        args = parser.parse_args()
+
+        return ToolManageService.list_api_tool_provider_tools(
+            user_id,
+            tenant_id,
+            args['provider'],
+        )
+
+
+class ToolApiProviderUpdateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+
+        user_id = current_user.id
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json')
+
+        args = parser.parse_args()
+
+        return ToolManageService.update_api_tool_provider(
+            user_id,
+            tenant_id,
+            args['provider'],
+            args['original_provider'],
+            args['icon'],
+            args['credentials'],
+            args['schema_type'],
+            args['schema'],
+            args['privacy_policy'],
+        )
+
+
+class ToolApiProviderDeleteApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+
+        user_id = current_user.id
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
+
+        args = parser.parse_args()
+
+        return ToolManageService.delete_api_tool_provider(
+            user_id,
+            tenant_id,
+            args['provider'],
+        )
+
+
+class ToolApiProviderGetApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        user_id = current_user.id
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
+
+        args = parser.parse_args()
+
+        return ToolManageService.get_api_tool_provider(
+            user_id,
+            tenant_id,
+            args['provider'],
+        )
+
+
+class ToolBuiltinProviderCredentialsSchemaApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider):
+        return ToolManageService.list_builtin_provider_credentials_schema(provider)
+
+
+class ToolApiProviderSchemaApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
+
+        args = parser.parse_args()
+
+        return ToolManageService.parser_api_schema(
+            schema=args['schema'],
+        )
+
+
+class ToolApiProviderPreviousTestApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
+
+        args = parser.parse_args()
+
+        return ToolManageService.test_api_tool_preview(
+            current_user.current_tenant_id,
+            args['tool_name'],
+            args['credentials'],
+            args['parameters'],
+            args['schema_type'],
+            args['schema'],
+        )
+
 
 api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
-api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers/<provider>/credentials')
-api.add_resource(ToolProviderCredentialsValidateApi,
-                 '/workspaces/current/tool-providers/<provider>/credentials-validate')
+api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin/<provider>/tools')
+api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin/<provider>/delete')
+api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
+api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
+api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
+api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
+api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
+api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
+api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update')
+api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete')
+api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get')
+api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema')
+api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre')
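The new /workspaces/current/tool-provider/api/add route takes a provider name, an icon object, credentials, and an OpenAPI-style schema passed as a string. A sketch of registering a provider against it (all values below are placeholders; the exact credential and schema shapes are defined by ToolManageService, which this diff does not show):

import json
import requests

BASE = 'http://localhost:5001/console/api'  # hypothetical host
payload = {
    'provider': 'weather',                            # placeholder name
    'icon': {'background': '#252525', 'content': '😁'},
    'credentials': {'auth_type': 'none'},             # assumed shape
    'schema_type': 'openapi',                         # assumed value
    'schema': json.dumps({'openapi': '3.0.0'}),       # truncated placeholder spec
    'privacy_policy': '',
}
r = requests.post(f'{BASE}/workspaces/current/tool-provider/api/add',
                  json=payload, cookies={'session': '...'})  # placeholder console auth
print(r.status_code, r.json())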
@@ -1,10 +1,12 @@
 # -*- coding:utf-8 -*-
+import json
 from functools import wraps
 
 from controllers.console.workspace.error import AccountNotInitializedError
-from flask import abort, current_app
+from flask import abort, current_app, request
 from flask_login import current_user
 from services.feature_service import FeatureService
+from services.operation_service import OperationService
 
 
 def account_initialization_required(view):
@@ -73,3 +75,20 @@ def cloud_edition_billing_resource_check(resource: str,
         return decorated
     return interceptor
 
 
+def cloud_utm_record(view):
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        try:
+            features = FeatureService.get_features(current_user.current_tenant_id)
+
+            if features.billing.enabled:
+                utm_info = request.cookies.get('utm_info')
+
+                if utm_info:
+                    utm_info = json.loads(utm_info)
+                    OperationService.record_utm(current_user.current_tenant_id, utm_info)
+        except Exception:
+            # UTM recording is best-effort; never block the request
+            pass
+        return view(*args, **kwargs)
+    return decorated
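cloud_utm_record expects a utm_info cookie containing a JSON object and forwards it to OperationService.record_utm, swallowing any failure so tracking can never break the wrapped request. A sketch of the cookie shape it implies (the exact keys are an assumption; only "JSON in a cookie named utm_info" is visible in this diff):

import json

utm_info = {
    'utm_source': 'newsletter',   # assumed keys, standard UTM naming
    'utm_medium': 'email',
    'utm_campaign': 'launch',
}
# set client-side before hitting the console API:
cookie_value = json.dumps(utm_info)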
@@ -6,4 +6,4 @@ bp = Blueprint('files', __name__)
 api = ExternalApi(bp)
 
 
-from . import image_preview
+from . import image_preview, tool_files
api/controllers/files/tool_files.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from controllers.files import api
from core.tools.tool_file_manager import ToolFileManager
from flask import Response
from flask_restful import Resource, reqparse
from libs.exception import BaseHTTPException
from werkzeug.exceptions import Forbidden, NotFound


class ToolFilePreviewApi(Resource):
    def get(self, file_id, extension):
        file_id = str(file_id)

        parser = reqparse.RequestParser()
        parser.add_argument('timestamp', type=str, required=True, location='args')
        parser.add_argument('nonce', type=str, required=True, location='args')
        parser.add_argument('sign', type=str, required=True, location='args')

        args = parser.parse_args()

        if not ToolFileManager.verify_file(file_id=file_id,
                                           timestamp=args['timestamp'],
                                           nonce=args['nonce'],
                                           sign=args['sign'],
                                           ):
            raise Forbidden('Invalid request.')

        try:
            result = ToolFileManager.get_file_generator_by_message_file_id(
                file_id,
            )
        except Exception:
            raise UnsupportedFileTypeError()

        # keep the 404 outside the try block so it is not swallowed
        # and re-raised as UnsupportedFileTypeError
        if not result:
            raise NotFound('File not found.')

        generator, mimetype = result

        return Response(generator, mimetype=mimetype)


api.add_resource(ToolFilePreviewApi, '/files/tools/<uuid:file_id>.<string:extension>')


class UnsupportedFileTypeError(BaseHTTPException):
    error_code = 'unsupported_file_type'
    description = "File type not allowed."
    code = 415
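Tool file previews are fetched with three query parameters — timestamp, nonce, and sign — that ToolFileManager.verify_file checks before streaming the file. The signing scheme itself is not part of this diff, so a client can only replay values the server issued; a sketch of the request shape (host and ids are placeholders):

import requests

BASE = 'http://localhost:5001'  # hypothetical host serving the files blueprint
file_id = '00000000-0000-0000-0000-000000000000'  # placeholder

r = requests.get(f'{BASE}/files/tools/{file_id}.png', params={
    'timestamp': '1700000000',     # issued by the server alongside the URL
    'nonce': 'abc123',
    'sign': 'base64-signature',    # placeholder; the scheme lives in ToolFileManager
})
r.raise_for_status()
open('preview.png', 'wb').write(r.content)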
@@ -1,9 +1,13 @@
 # -*- coding:utf-8 -*-
+import json
 from controllers.service_api import api
 from controllers.service_api.wraps import AppApiResource
+from extensions.ext_database import db
+from flask import current_app
 from flask_restful import fields, marshal_with
-from models.model import App
+from models.model import App, AppModelConfig
+from models.tools import ApiToolProvider
 
 
 class AppParameterApi(AppApiResource):
@@ -28,6 +32,7 @@ class AppParameterApi(AppApiResource):
         'suggested_questions': fields.Raw,
         'suggested_questions_after_answer': fields.Raw,
         'speech_to_text': fields.Raw,
+        'text_to_speech': fields.Raw,
         'retriever_resource': fields.Raw,
         'annotation_reply': fields.Raw,
         'more_like_this': fields.Raw,
@@ -47,6 +52,7 @@ class AppParameterApi(AppApiResource):
         'suggested_questions': app_model_config.suggested_questions_list,
         'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
         'speech_to_text': app_model_config.speech_to_text_dict,
+        'text_to_speech': app_model_config.text_to_speech_dict,
         'retriever_resource': app_model_config.retriever_resource_dict,
         'annotation_reply': app_model_config.annotation_reply_dict,
         'more_like_this': app_model_config.more_like_this_dict,
@@ -58,5 +64,42 @@ class AppParameterApi(AppApiResource):
         }
     }
 
+
+class AppMetaApi(AppApiResource):
+    def get(self, app_model: App, end_user):
+        """Get app meta"""
+        app_model_config: AppModelConfig = app_model.app_model_config
+
+        agent_config = app_model_config.agent_mode_dict or {}
+        meta = {
+            'tool_icons': {}
+        }
+
+        # get all tools
+        tools = agent_config.get('tools', [])
+        url_prefix = (current_app.config.get("CONSOLE_API_URL")
+                      + "/console/api/workspaces/current/tool-provider/builtin/")
+        for tool in tools:
+            keys = list(tool.keys())
+            if len(keys) >= 4:
+                # current tool standard
+                provider_type = tool.get('provider_type')
+                provider_id = tool.get('provider_id')
+                tool_name = tool.get('tool_name')
+                if provider_type == 'builtin':
+                    meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon'
+                elif provider_type == 'api':
+                    try:
+                        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
+                            ApiToolProvider.id == provider_id
+                        ).first()  # .first() materializes the row; without it this is a Query object
+                        meta['tool_icons'][tool_name] = json.loads(provider.icon)
+                    except Exception:
+                        meta['tool_icons'][tool_name] = {
+                            "background": "#252525",
+                            "content": "\ud83d\ude01"
+                        }
+
+        return meta
+
 
 api.add_resource(AppParameterApi, '/parameters')
+api.add_resource(AppMetaApi, '/meta')
@@ -10,6 +10,7 @@ from controllers.service_api.wraps import AppApiResource
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from flask import request
+from flask_restful import reqparse
 from models.model import App, AppModelConfig
 from services.audio_service import AudioService
 from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError,
@@ -22,14 +23,15 @@ class AudioApi(AppApiResource):
         app_model_config: AppModelConfig = app_model.app_model_config
 
         if not app_model_config.speech_to_text_dict['enabled']:
             raise AppUnavailableError()
 
         file = request.files['file']
 
         try:
-            response = AudioService.transcript(
+            response = AudioService.transcript_asr(
                 tenant_id=app_model.tenant_id,
                 file=file,
+                end_user=end_user
             )
 
             return response
@@ -57,5 +59,50 @@ class AudioApi(AppApiResource):
         except Exception as e:
             logging.exception("internal server error.")
             raise InternalServerError()
 
-api.add_resource(AudioApi, '/audio-to-text')
 
+
+class TextApi(AppApiResource):
+    def post(self, app_model: App, end_user):
+        parser = reqparse.RequestParser()
+        parser.add_argument('text', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('user', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
+        args = parser.parse_args()
+
+        try:
+            response = AudioService.transcript_tts(
+                tenant_id=app_model.tenant_id,
+                text=args['text'],
+                end_user=args['user'],
+                streaming=args['streaming']
+            )
+
+            return response
+        except services.errors.app_model_config.AppModelConfigBrokenError:
+            logging.exception("App model config broken.")
+            raise AppUnavailableError()
+        except NoAudioUploadedServiceError:
+            raise NoAudioUploadedError()
+        except AudioTooLargeServiceError as e:
+            raise AudioTooLargeError(str(e))
+        except UnsupportedAudioTypeServiceError:
+            raise UnsupportedAudioTypeError()
+        except ProviderNotSupportSpeechToTextServiceError:
+            raise ProviderNotSupportSpeechToTextError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+        except ValueError as e:
+            raise e
+        except Exception as e:
+            logging.exception("internal server error.")
+            raise InternalServerError()
+
+
+api.add_resource(AudioApi, '/audio-to-text')
+api.add_resource(TextApi, '/text-to-audio')
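The new /text-to-audio route mirrors /audio-to-text: it takes the text, a user identifier, and an optional streaming flag, and returns synthesized audio via AudioService.transcript_tts. A sketch of calling it over the service API (host, token, and the assumption that the response body is raw audio are placeholders):

import requests

BASE = 'http://localhost:5001/v1'  # hypothetical service API mount
r = requests.post(f'{BASE}/text-to-audio',
                  headers={'Authorization': 'Bearer app-...'},  # placeholder API key
                  json={'text': 'Hello from Dify', 'user': 'user-123', 'streaming': False})
r.raise_for_status()
open('speech.wav', 'wb').write(r.content)  # assumed audio body; format depends on the provider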
@@ -13,7 +13,7 @@ from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from flask import Response, stream_with_context, request
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import reqparse
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
@@ -75,18 +75,22 @@ class CompletionApi(AppApiResource):
|
||||
|
||||
|
||||
class CompletionStopApi(AppApiResource):
|
||||
def post(self, app_model, _, task_id):
|
||||
def post(self, app_model, end_user, task_id):
|
||||
if app_model.mode != 'completion':
|
||||
raise AppUnavailableError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
|
||||
if end_user is None:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
args = parser.parse_args()
|
||||
user = args.get('user')
|
||||
if user is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user)
|
||||
else:
|
||||
raise ValueError("arg user muse be input.")
|
||||
|
||||
end_user_id = args.get('user')
|
||||
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user_id)
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
@@ -146,13 +150,22 @@ class ChatApi(AppApiResource):
|
||||
|
||||
|
||||
class ChatStopApi(AppApiResource):
|
||||
def post(self, app_model, _, task_id):
|
||||
def post(self, app_model, end_user, task_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
end_user_id = request.get_json().get('user')
|
||||
if end_user is None:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user_id)
|
||||
user = args.get('user')
|
||||
if user is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user)
|
||||
else:
|
||||
raise ValueError("arg user muse be input.")
|
||||
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
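Both stop endpoints now resolve an end user before setting the stop flag, so a caller only needs to echo the same `user` it generated with. A minimal sketch, assuming the chat stop route is mounted at /chat-messages/{task_id}/stop (the route registration is outside this hunk):

import requests

task_id = '9f7c0a1e-0000-0000-0000-000000000000'  # hypothetical task id from a streaming response
requests.post(
    f'https://api.dify.example/v1/chat-messages/{task_id}/stop',
    headers={'Authorization': 'Bearer app-xxxxxxxx'},
    json={'user': 'user-123'},
)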

@@ -37,6 +37,20 @@ class MessageListApi(AppApiResource):
        'created_at': TimestampField
    }

    agent_thought_fields = {
        'id': fields.String,
        'chain_id': fields.String,
        'message_id': fields.String,
        'position': fields.Integer,
        'thought': fields.String,
        'tool': fields.String,
        'tool_labels': fields.Raw,
        'tool_input': fields.String,
        'created_at': TimestampField,
        'observation': fields.String,
        'message_files': fields.List(fields.String, attribute='files')
    }
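Marshalled through these fields, one entry of `agent_thoughts` in the message payload looks roughly like this (all values illustrative, and the exact `tool_labels` shape depends on the stored tool metadata):

# Hypothetical serialized agent thought
{
    'id': 'a1b2c3d4-...',
    'chain_id': 'e5f6a7b8-...',
    'message_id': '0c1d2e3f-...',
    'position': 1,
    'thought': 'I should call the search tool first.',
    'tool': 'google_search',
    'tool_labels': {'google_search': {'en_US': 'Google Search'}},
    'tool_input': '{"query": "dify agent"}',
    'observation': '<tool output>',
    'created_at': 1706000000,
    'message_files': []
}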

    message_fields = {
        'id': fields.String,
        'conversation_id': fields.String,
@@ -46,7 +60,8 @@ class MessageListApi(AppApiResource):
        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
        'created_at': TimestampField
        'created_at': TimestampField,
        'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
    }

    message_infinite_scroll_pagination_fields = {

@@ -1,4 +1,3 @@
from models.dataset import Dataset
import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetNameDuplicateError
@@ -9,6 +8,7 @@ from fields.dataset_fields import dataset_detail_fields
from flask import request
from flask_restful import marshal, reqparse
from libs.login import current_user
from models.dataset import Dataset
from services.dataset_service import DatasetService



@@ -1,8 +1,7 @@
from controllers.service_api import api
from flask import current_app
from flask_restful import Resource

from controllers.service_api import api


class IndexApi(Resource):
    def get(self):

@@ -75,8 +75,8 @@ def validate_dataset_token(view=None):
        tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
            .filter(Tenant.id == api_token.tenant_id) \
            .filter(TenantAccountJoin.tenant_id == Tenant.id) \
            .filter(TenantAccountJoin.role.in_(['owner', 'admin'])) \
            .one_or_none()
            .filter(TenantAccountJoin.role.in_(['owner'])) \
            .one_or_none()  # TODO: only owner information is required, so only one is returned.
        if tenant_account_join:
            tenant, ta = tenant_account_join
            account = Account.query.filter_by(id=ta.account_id).first()
@@ -86,9 +86,9 @@ def validate_dataset_token(view=None):
            current_app.login_manager._update_request_context_with_user(account)
            user_logged_in.send(current_app._get_current_object(), user=_get_user())
        else:
            raise Unauthorized("Tenant owner account is not exist.")
            raise Unauthorized("Tenant owner account does not exist.")
    else:
        raise Unauthorized("Tenant is not exist.")
        raise Unauthorized("Tenant does not exist.")
    return view(api_token.tenant_id, *args, **kwargs)
return decorated


@@ -1,9 +1,13 @@
# -*- coding:utf-8 -*-
import json

from controllers.web import api
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from flask import current_app
from flask_restful import fields, marshal_with
from models.model import App
from models.model import App, AppModelConfig
from models.tools import ApiToolProvider


class AppParameterApi(WebApiResource):
@@ -27,6 +31,7 @@ class AppParameterApi(WebApiResource):
        'suggested_questions': fields.Raw,
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
        'text_to_speech': fields.Raw,
        'retriever_resource': fields.Raw,
        'annotation_reply': fields.Raw,
        'more_like_this': fields.Raw,
@@ -46,6 +51,7 @@ class AppParameterApi(WebApiResource):
            'suggested_questions': app_model_config.suggested_questions_list,
            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
            'speech_to_text': app_model_config.speech_to_text_dict,
            'text_to_speech': app_model_config.text_to_speech_dict,
            'retriever_resource': app_model_config.retriever_resource_dict,
            'annotation_reply': app_model_config.annotation_reply_dict,
            'more_like_this': app_model_config.more_like_this_dict,
@@ -57,5 +63,42 @@ class AppParameterApi(WebApiResource):
            }
        }

class AppMeta(WebApiResource):
    def get(self, app_model: App, end_user):
        """Get app meta"""
        app_model_config: AppModelConfig = app_model.app_model_config

        agent_config = app_model_config.agent_mode_dict or {}
        meta = {
            'tool_icons': {}
        }

        # get all tools
        tools = agent_config.get('tools', [])
        url_prefix = (current_app.config.get("CONSOLE_API_URL")
                      + "/console/api/workspaces/current/tool-provider/builtin/")
        for tool in tools:
            keys = list(tool.keys())
            if len(keys) >= 4:
                # current tool standard
                provider_type = tool.get('provider_type')
                provider_id = tool.get('provider_id')
                tool_name = tool.get('tool_name')
                if provider_type == 'builtin':
                    meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon'
                elif provider_type == 'api':
                    try:
                        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
                            ApiToolProvider.id == provider_id
                        ).first()
                        meta['tool_icons'][tool_name] = json.loads(provider.icon)
                    except Exception:
                        meta['tool_icons'][tool_name] = {
                            "background": "#252525",
                            "content": "\ud83d\ude01"
                        }

        return meta

api.add_resource(AppParameterApi, '/parameters')
api.add_resource(AppMeta, '/meta')
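The endpoint maps each configured tool to an icon: a console URL for builtin providers, or the stored icon object (with the emoji fallback above) for API providers. An illustrative response, with hypothetical values:

# GET /meta
{
    'tool_icons': {
        'google_search': 'https://cloud.dify.example/console/api/workspaces/current/tool-provider/builtin/google/icon',
        'my_api_tool': {'background': '#252525', 'content': '😁'}
    }
}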

@@ -28,9 +28,10 @@ class AudioApi(WebApiResource):
        file = request.files['file']

        try:
            response = AudioService.transcript(
            response = AudioService.transcript_asr(
                tenant_id=app_model.tenant_id,
                file=file,
                end_user=end_user
            )

            return response
@@ -59,4 +60,43 @@ class AudioApi(WebApiResource):
            logging.exception("internal server error.")
            raise InternalServerError()

api.add_resource(AudioApi, '/audio-to-text')

class TextApi(WebApiResource):
    def post(self, app_model: App, end_user):
        try:
            response = AudioService.transcript_tts(
                tenant_id=app_model.tenant_id,
                text=request.form['text'],
                end_user=end_user.external_user_id,
                streaming=False
            )

            return {'data': response.data.decode('latin1')}
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
        except NoAudioUploadedServiceError:
            raise NoAudioUploadedError()
        except AudioTooLargeServiceError as e:
            raise AudioTooLargeError(str(e))
        except UnsupportedAudioTypeServiceError:
            raise UnsupportedAudioTypeError()
        except ProviderNotSupportSpeechToTextServiceError:
            raise ProviderNotSupportSpeechToTextError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()


api.add_resource(AudioApi, '/audio-to-text')
api.add_resource(TextApi, '/text-to-audio')
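Decoding the audio bytes as latin1 makes them JSON-safe without loss, since latin1 maps every byte 0-255 to a single code point. A client recovers the original bytes by re-encoding; a minimal sketch:

# Hypothetical client-side handling of the {'data': ...} payload
payload = resp.json()
audio_bytes = payload['data'].encode('latin1')  # exact inverse of bytes.decode('latin1')
with open('reply.mp3', 'wb') as f:  # actual container depends on the TTS provider
    f.write(audio_bytes)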

@@ -14,6 +14,7 @@ from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields
from flask import Response, stream_with_context
from flask_restful import fields, marshal_with, reqparse
from flask_restful.inputs import int_range
@@ -59,7 +60,8 @@ class MessageListApi(WebApiResource):
        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
        'created_at': TimestampField
        'created_at': TimestampField,
        'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
    }

    message_infinite_scroll_pagination_fields = {

@@ -13,8 +13,8 @@ from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from langchain.agents import AgentExecutor as LCAgentExecutor
from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent
from langchain.callbacks.manager import Callbacks
@@ -1,251 +0,0 @@
import json
import logging
from typing import cast

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, PromptTemplateEntity
from core.features.agent_runner import AgentRunnerFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain

logger = logging.getLogger(__name__)


class AgentApplicationRunner(AppRunner):
    """
    Agent Application Runner
    """

    def run(self, application_generate_entity: ApplicationGenerateEntity,
            queue_manager: ApplicationQueueManager,
            conversation: Conversation,
            message: Message) -> None:
        """
        Run agent application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError(f"App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            memory = TokenBufferMemory(
                conversation=conversation,
                model_instance=model_instance
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        #          memory(optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=None,
            memory=memory
        )

        # Create MessageChain
        message_chain = self._init_message_chain(
            message=message,
            query=query
        )

        # add agent callback to record agent thoughts
        agent_callback = AgentLoopGatherCallbackHandler(
            model_config=app_orchestration_config.model_config,
            message=message,
            queue_manager=queue_manager,
            message_chain=message_chain
        )

        # init LLM Callback
        agent_llm_callback = AgentLLMCallback(
            agent_callback=agent_callback
        )

        agent_runner = AgentRunnerFeature(
            tenant_id=application_generate_entity.tenant_id,
            app_orchestration_config=app_orchestration_config,
            model_config=app_orchestration_config.model_config,
            config=app_orchestration_config.agent,
            queue_manager=queue_manager,
            message=message,
            user_id=application_generate_entity.user_id,
            agent_llm_callback=agent_llm_callback,
            callback=agent_callback,
            memory=memory
        )

        # agent run
        result = agent_runner.run(
            query=query,
            invoke_from=application_generate_entity.invoke_from
        )

        if result:
            self._save_message_chain(
                message_chain=message_chain,
                output_text=result
            )

        if (result
                and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE
                and app_orchestration_config.prompt_template.simple_prompt_template
        ):
            # Direct output if agent result exists and has pre prompt
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=app_orchestration_config,
                prompt_messages=prompt_messages,
                stream=application_generate_entity.stream,
                text=result,
                usage=self._get_usage_of_all_agent_thoughts(
                    model_config=app_orchestration_config.model_config,
                    message=message
                )
            )
        else:
            # As normal LLM run, agent result as context
            context = result

            # reorganize all inputs and template to prompt messages
            # Include: prompt template, inputs, query(optional), files(optional)
            #          memory(optional), external data, dataset context(optional)
            prompt_messages, stop = self.organize_prompt_messages(
                app_record=app_record,
                model_config=app_orchestration_config.model_config,
                prompt_template_entity=app_orchestration_config.prompt_template,
                inputs=inputs,
                files=files,
                query=query,
                context=context,
                memory=memory
            )

            # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
            self.recale_llm_max_tokens(
                model_config=app_orchestration_config.model_config,
                prompt_messages=prompt_messages
            )

            # Invoke model
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            invoke_result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                stop=stop,
                stream=application_generate_entity.stream,
                user=application_generate_entity.user_id,
            )

            # handle invoke result
            self._handle_invoke_result(
                invoke_result=invoke_result,
                queue_manager=queue_manager,
                stream=application_generate_entity.stream
            )

    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
        """
        Init MessageChain
        :param message: message
        :param query: query
        :return:
        """
        message_chain = MessageChain(
            message_id=message.id,
            type="AgentExecutor",
            input=json.dumps({
                "input": query
            })
        )

        db.session.add(message_chain)
        db.session.commit()

        return message_chain

    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
        """
        Save MessageChain
        :param message_chain: message chain
        :param output_text: output text
        :return:
        """
        message_chain.output = json.dumps({
            "output": output_text
        })
        db.session.commit()

    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                         message: Message) -> LLMUsage:
        """
        Get usage of all agent thoughts
        :param model_config: model config
        :param message: message
        :return:
        """
        agent_thoughts = (db.session.query(MessageAgentThought)
                          .filter(MessageAgentThought.message_id == message.id).all())

        all_message_tokens = 0
        all_answer_tokens = 0
        for agent_thought in agent_thoughts:
            all_message_tokens += agent_thought.message_token
            all_answer_tokens += agent_thought.answer_token

        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        return model_type_instance._calc_response_usage(
            model_config.model,
            model_config.credentials,
            all_message_tokens,
            all_answer_tokens
        )
@@ -2,7 +2,13 @@ import time
from typing import Generator, List, Optional, Tuple, Union, cast

from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import AppOrchestrationConfigEntity, ModelConfigEntity, PromptTemplateEntity
from core.entities.application_entities import (ApplicationGenerateEntity, AppOrchestrationConfigEntity,
                                                ExternalDataVariableEntity, InvokeFrom, ModelConfigEntity,
                                                PromptTemplateEntity)
from core.features.annotation_reply import AnnotationReplyFeature
from core.features.external_data_fetch import ExternalDataFetchFeature
from core.features.hosting_moderation import HostingModerationFeature
from core.features.moderation import ModerationFeature
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -11,7 +17,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform
from models.model import App
from models.model import App, Message, MessageAnnotation


class AppRunner:
@@ -199,7 +205,8 @@ class AppRunner:

    def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
                              queue_manager: ApplicationQueueManager,
                              stream: bool) -> None:
                              stream: bool,
                              agent: bool = False) -> None:
        """
        Handle invoke result
        :param invoke_result: invoke result
@@ -210,16 +217,19 @@ class AppRunner:
        if not stream:
            self._handle_invoke_result_direct(
                invoke_result=invoke_result,
                queue_manager=queue_manager
                queue_manager=queue_manager,
                agent=agent
            )
        else:
            self._handle_invoke_result_stream(
                invoke_result=invoke_result,
                queue_manager=queue_manager
                queue_manager=queue_manager,
                agent=agent
            )

    def _handle_invoke_result_direct(self, invoke_result: LLMResult,
                                     queue_manager: ApplicationQueueManager) -> None:
                                     queue_manager: ApplicationQueueManager,
                                     agent: bool) -> None:
        """
        Handle invoke result direct
        :param invoke_result: invoke result
@@ -232,7 +242,8 @@ class AppRunner:
        )

    def _handle_invoke_result_stream(self, invoke_result: Generator,
                                     queue_manager: ApplicationQueueManager) -> None:
                                     queue_manager: ApplicationQueueManager,
                                     agent: bool) -> None:
        """
        Handle invoke result
        :param invoke_result: invoke result
@@ -244,7 +255,10 @@ class AppRunner:
        text = ''
        usage = None
        for result in invoke_result:
            queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER)
            if not agent:
                queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER)
            else:
                queue_manager.publish_agent_chunk_message(result, PublishFrom.APPLICATION_MANAGER)

            text += result.delta.message.content

@@ -271,3 +285,101 @@ class AppRunner:
            llm_result=llm_result,
            pub_from=PublishFrom.APPLICATION_MANAGER
        )

    def moderation_for_inputs(self, app_id: str,
                              tenant_id: str,
                              app_orchestration_config_entity: AppOrchestrationConfigEntity,
                              inputs: dict,
                              query: str) -> Tuple[bool, dict, str]:
        """
        Process sensitive_word_avoidance.
        :param app_id: app id
        :param tenant_id: tenant id
        :param app_orchestration_config_entity: app orchestration config entity
        :param inputs: inputs
        :param query: query
        :return:
        """
        moderation_feature = ModerationFeature()
        return moderation_feature.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_orchestration_config_entity=app_orchestration_config_entity,
            inputs=inputs,
            query=query,
        )

    def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
                                 queue_manager: ApplicationQueueManager,
                                 prompt_messages: list[PromptMessage]) -> bool:
        """
        Check hosting moderation
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param prompt_messages: prompt messages
        :return:
        """
        hosting_moderation_feature = HostingModerationFeature()
        moderation_result = hosting_moderation_feature.check(
            application_generate_entity=application_generate_entity,
            prompt_messages=prompt_messages
        )

        if moderation_result:
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=application_generate_entity.app_orchestration_config_entity,
                prompt_messages=prompt_messages,
                text="I apologize for any confusion, " \
                     "but I'm an AI assistant to be helpful, harmless, and honest.",
                stream=application_generate_entity.stream
            )

        return moderation_result

    def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
                                                app_id: str,
                                                external_data_tools: list[ExternalDataVariableEntity],
                                                inputs: dict,
                                                query: str) -> dict:
        """
        Fill in variable inputs from external data tools if exists.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        external_data_fetch_feature = ExternalDataFetchFeature()
        return external_data_fetch_feature.fetch(
            tenant_id=tenant_id,
            app_id=app_id,
            external_data_tools=external_data_tools,
            inputs=inputs,
            query=query
        )

    def query_app_annotations_to_reply(self, app_record: App,
                                       message: Message,
                                       query: str,
                                       user_id: str,
                                       invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply
        :param app_record: app record
        :param message: message
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :return:
        """
        annotation_reply_feature = AnnotationReplyFeature()
        return annotation_reply_feature.query(
            app_record=app_record,
            message=message,
            query=query,
            user_id=user_id,
            invoke_from=invoke_from
        )
api/core/app_runner/assistant_app_runner.py (new file, 349 lines)
@@ -0,0 +1,349 @@
import json
import logging
from typing import cast

from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity
from core.features.assistant_cot_runner import AssistantCotApplicationRunner
from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)


class AssistantApplicationRunner(AppRunner):
    """
    Assistant Application Runner
    """
    def run(self, application_generate_entity: ApplicationGenerateEntity,
            queue_manager: ApplicationQueueManager,
            conversation: Conversation,
            message: Message) -> None:
        """
        Run assistant application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError("App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            memory = TokenBufferMemory(
                conversation=conversation,
                model_instance=model_instance
            )

        # organize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        #          memory(optional)
        prompt_messages, _ = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory
        )

        # moderation
        try:
            # process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=application_generate_entity.tenant_id,
                app_orchestration_config_entity=app_orchestration_config,
                inputs=inputs,
                query=query,
            )
        except ModerationException as e:
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=app_orchestration_config,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream
            )
            return

        if query:
            # annotation reply
            annotation_reply = self.query_app_annotations_to_reply(
                app_record=app_record,
                message=message,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from
            )

            if annotation_reply:
                queue_manager.publish_annotation_reply(
                    message_annotation_id=annotation_reply.id,
                    pub_from=PublishFrom.APPLICATION_MANAGER
                )
                self.direct_output(
                    queue_manager=queue_manager,
                    app_orchestration_config=app_orchestration_config,
                    prompt_messages=prompt_messages,
                    text=annotation_reply.content,
                    stream=application_generate_entity.stream
                )
                return

        # fill in variable inputs from external data tools if exists
        external_data_tools = app_orchestration_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        #          memory(optional), external data, dataset context(optional)
        prompt_messages, _ = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages
        )

        if hosting_moderation_result:
            return

        agent_entity = app_orchestration_config.agent

        # load tool variables
        tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
                                                                user_id=application_generate_entity.user_id,
                                                                tenant_id=application_generate_entity.tenant_id)

        # convert db variables to tool variables
        tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)

        message_chain = self._init_message_chain(
            message=message,
            query=query
        )

        # init model instance
        model_instance = ModelInstance(
            provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
            model=app_orchestration_config.model_config.model
        )
        prompt_message, _ = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory,
        )

        # change function call strategy based on LLM model
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)

        if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
            agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

        # start agent runner
        if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
            assistant_cot_runner = AssistantCotApplicationRunner(
                tenant_id=application_generate_entity.tenant_id,
                application_generate_entity=application_generate_entity,
                app_orchestration_config=app_orchestration_config,
                model_config=app_orchestration_config.model_config,
                config=agent_entity,
                queue_manager=queue_manager,
                message=message,
                user_id=application_generate_entity.user_id,
                memory=memory,
                prompt_messages=prompt_message,
                variables_pool=tool_variables,
                db_variables=tool_conversation_variables,
                model_instance=model_instance
            )
            invoke_result = assistant_cot_runner.run(
                conversation=conversation,
                message=message,
                query=query,
            )
        elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
            assistant_fc_runner = AssistantFunctionCallApplicationRunner(
                tenant_id=application_generate_entity.tenant_id,
                application_generate_entity=application_generate_entity,
                app_orchestration_config=app_orchestration_config,
                model_config=app_orchestration_config.model_config,
                config=agent_entity,
                queue_manager=queue_manager,
                message=message,
                user_id=application_generate_entity.user_id,
                memory=memory,
                prompt_messages=prompt_message,
                variables_pool=tool_variables,
                db_variables=tool_conversation_variables,
                model_instance=model_instance
            )
            invoke_result = assistant_fc_runner.run(
                conversation=conversation,
                message=message,
                query=query,
            )

        # handle invoke result
        self._handle_invoke_result(
            invoke_result=invoke_result,
            queue_manager=queue_manager,
            stream=application_generate_entity.stream,
            agent=True
        )

    def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
        """
        load tool variables from database
        """
        tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
            ToolConversationVariables.conversation_id == conversation_id,
            ToolConversationVariables.tenant_id == tenant_id
        ).first()

        if tool_variables:
            # save tool variables to session, so that we can update it later
            db.session.add(tool_variables)
        else:
            # create new tool variables
            tool_variables = ToolConversationVariables(
                conversation_id=conversation_id,
                user_id=user_id,
                tenant_id=tenant_id,
                variables_str='[]',
            )
            db.session.add(tool_variables)
            db.session.commit()

        return tool_variables

    def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
        """
        convert db variables to tool variables
        """
        return ToolRuntimeVariablePool(**{
            'conversation_id': db_variables.conversation_id,
            'user_id': db_variables.user_id,
            'tenant_id': db_variables.tenant_id,
            'pool': db_variables.variables
        })

    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
        """
        Init MessageChain
        :param message: message
        :param query: query
        :return:
        """
        message_chain = MessageChain(
            message_id=message.id,
            type="AgentExecutor",
            input=json.dumps({
                "input": query
            })
        )

        db.session.add(message_chain)
        db.session.commit()

        return message_chain

    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
        """
        Save MessageChain
        :param message_chain: message chain
        :param output_text: output text
        :return:
        """
        message_chain.output = json.dumps({
            "output": output_text
        })
        db.session.commit()

    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                         message: Message) -> LLMUsage:
        """
        Get usage of all agent thoughts
        :param model_config: model config
        :param message: message
        :return:
        """
        agent_thoughts = (db.session.query(MessageAgentThought)
                          .filter(MessageAgentThought.message_id == message.id).all())

        all_message_tokens = 0
        all_answer_tokens = 0
        for agent_thought in agent_thoughts:
            all_message_tokens += agent_thought.message_tokens
            all_answer_tokens += agent_thought.answer_tokens

        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        return model_type_instance._calc_response_usage(
            model_config.model,
            model_config.credentials,
            all_message_tokens,
            all_answer_tokens
        )
@@ -1,23 +1,17 @@
import logging
from typing import Optional, Tuple
from typing import Optional

from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import (ApplicationGenerateEntity, AppOrchestrationConfigEntity, DatasetEntity,
                                                ExternalDataVariableEntity, InvokeFrom, ModelConfigEntity)
from core.features.annotation_reply import AnnotationReplyFeature
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.external_data_fetch import ExternalDataFetchFeature
from core.features.hosting_moderation import HostingModerationFeature
from core.features.moderation import ModerationFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.moderation.base import ModerationException
from core.prompt.prompt_transform import AppMode
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAnnotation
from models.model import App, Conversation, Message

logger = logging.getLogger(__name__)

@@ -213,76 +207,6 @@ class BasicApplicationRunner(AppRunner):
            stream=application_generate_entity.stream
        )

    def moderation_for_inputs(self, app_id: str,
                              tenant_id: str,
                              app_orchestration_config_entity: AppOrchestrationConfigEntity,
                              inputs: dict,
                              query: str) -> Tuple[bool, dict, str]:
        """
        Process sensitive_word_avoidance.
        :param app_id: app id
        :param tenant_id: tenant id
        :param app_orchestration_config_entity: app orchestration config entity
        :param inputs: inputs
        :param query: query
        :return:
        """
        moderation_feature = ModerationFeature()
        return moderation_feature.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_orchestration_config_entity=app_orchestration_config_entity,
            inputs=inputs,
            query=query,
        )

    def query_app_annotations_to_reply(self, app_record: App,
                                       message: Message,
                                       query: str,
                                       user_id: str,
                                       invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply
        :param app_record: app record
        :param message: message
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :return:
        """
        annotation_reply_feature = AnnotationReplyFeature()
        return annotation_reply_feature.query(
            app_record=app_record,
            message=message,
            query=query,
            user_id=user_id,
            invoke_from=invoke_from
        )

    def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
                                                app_id: str,
                                                external_data_tools: list[ExternalDataVariableEntity],
                                                inputs: dict,
                                                query: str) -> dict:
        """
        Fill in variable inputs from external data tools if exists.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        external_data_fetch_feature = ExternalDataFetchFeature()
        return external_data_fetch_feature.fetch(
            tenant_id=tenant_id,
            app_id=app_id,
            external_data_tools=external_data_tools,
            inputs=inputs,
            query=query
        )

    def retrieve_dataset_context(self, tenant_id: str,
                                 app_record: App,
                                 queue_manager: ApplicationQueueManager,
@@ -334,31 +258,4 @@ class BasicApplicationRunner(AppRunner):
            hit_callback=hit_callback,
            memory=memory
        )

    def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
                                 queue_manager: ApplicationQueueManager,
                                 prompt_messages: list[PromptMessage]) -> bool:
        """
        Check hosting moderation
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param prompt_messages: prompt messages
        :return:
        """
        hosting_moderation_feature = HostingModerationFeature()
        moderation_result = hosting_moderation_feature.check(
            application_generate_entity=application_generate_entity,
            prompt_messages=prompt_messages
        )

        if moderation_result:
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=application_generate_entity.app_orchestration_config_entity,
                prompt_messages=prompt_messages,
                text="I apologize for any confusion, " \
                     "but I'm an AI assistant to be helpful, harmless, and honest.",
                stream=application_generate_entity.stream
            )

        return moderation_result
@@ -6,10 +6,11 @@ from typing import Generator, Optional, Union, cast
from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentThoughtEvent, QueueErrorEvent,
                                          QueueMessageEndEvent, QueueMessageEvent, QueueMessageReplaceEvent,
                                          QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent)
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentMessageEvent, QueueAgentThoughtEvent,
                                          QueueErrorEvent, QueueMessageEndEvent, QueueMessageEvent,
                                          QueueMessageFileEvent, QueueMessageReplaceEvent, QueuePingEvent,
                                          QueueRetrieverResourcesEvent, QueueStopEvent)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
                                                          PromptMessage, PromptMessageContentType, PromptMessageRole,
@@ -18,9 +19,11 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.prompt_template import PromptTemplateParser
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from pydantic import BaseModel
from services.annotation_service import AppAnnotationService

@@ -279,11 +282,12 @@ class GenerateTaskPipeline:

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueAgentThoughtEvent):
                agent_thought = (
                agent_thought: MessageAgentThought = (
                    db.session.query(MessageAgentThought)
                    .filter(MessageAgentThought.id == event.agent_thought_id)
                    .first()
                )
                db.session.refresh(agent_thought)

                if agent_thought:
                    response = {
@@ -293,16 +297,49 @@ class GenerateTaskPipeline:
                        'message_id': self._message.id,
                        'position': agent_thought.position,
                        'thought': agent_thought.thought,
                        'observation': agent_thought.observation,
                        'tool': agent_thought.tool,
                        'tool_labels': agent_thought.tool_labels,
                        'tool_input': agent_thought.tool_input,
                        'created_at': int(self._message.created_at.timestamp())
                        'created_at': int(self._message.created_at.timestamp()),
                        'message_files': agent_thought.files
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageEvent):
            elif isinstance(event, QueueMessageFileEvent):
                message_file: MessageFile = (
                    db.session.query(MessageFile)
                    .filter(MessageFile.id == event.message_file_id)
                    .first()
                )
                # get extension
                if '.' in message_file.url:
                    extension = f'.{message_file.url.split(".")[-1]}'
                    if len(extension) > 10:
                        extension = '.bin'
                else:
                    extension = '.bin'
                # add sign url
                url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension)

                if message_file:
                    response = {
                        'event': 'message_file',
                        'id': message_file.id,
                        'type': message_file.type,
                        'belongs_to': message_file.belongs_to or 'user',
                        'url': url
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)

            elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
@@ -332,7 +369,7 @@ class GenerateTaskPipeline:
                    self._output_moderation_handler.append_new_token(delta_text)

                self._task_state.llm_result.message.content += delta_text
                response = self._handle_chunk(delta_text)
                response = self._handle_chunk(delta_text, agent=isinstance(event, QueueAgentMessageEvent))
                yield self._yield_response(response)
            elif isinstance(event, QueueMessageReplaceEvent):
                response = {
@@ -384,14 +421,14 @@ class GenerateTaskPipeline:
            extras=self._application_generate_entity.extras
        )

    def _handle_chunk(self, text: str) -> dict:
    def _handle_chunk(self, text: str, agent: bool = False) -> dict:
        """
        Handle completed event.
        :param text: text
        :return:
        """
        response = {
            'event': 'message',
            'event': 'message' if not agent else 'agent_message',
            'id': self._message.id,
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
@@ -493,6 +530,10 @@ class GenerateTaskPipeline:
                'score': resource['score'],
                'content': resource['content'],
            })
        # show annotation reply
        if 'annotation_reply' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']

        # show usage
        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:

@@ -116,7 +116,7 @@ class OutputModerationHandler(BaseModel):

        # trigger replace event
        if self.thread_running:
            self.on_message_replace_func(final_output)
            self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)

        if result.action == ModerationAction.DIRECT_OUTPUT:
            break
@@ -4,13 +4,14 @@ import threading
import uuid
from typing import Any, Generator, Optional, Tuple, Union, cast

from core.app_runner.agent_app_runner import AgentApplicationRunner
from core.app_runner.assistant_app_runner import AssistantApplicationRunner
from core.app_runner.basic_app_runner import BasicApplicationRunner
from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
from core.entities.application_entities import (AdvancedChatPromptTemplateEntity,
                                                AdvancedCompletionPromptTemplateEntity, AgentEntity, AgentToolEntity,
                                                ApplicationGenerateEntity, AppOrchestrationConfigEntity, DatasetEntity,
                                                AdvancedCompletionPromptTemplateEntity, AgentEntity, AgentPromptEntity,
                                                AgentToolEntity, ApplicationGenerateEntity,
                                                AppOrchestrationConfigEntity, DatasetEntity,
                                                DatasetRetrieveConfigEntity, ExternalDataVariableEntity,
                                                FileUploadEntity, InvokeFrom, ModelConfigEntity, PromptTemplateEntity,
                                                SensitiveWordAvoidanceEntity)
@@ -23,6 +24,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from core.provider_manager import ProviderManager
from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
from extensions.ext_database import db
from flask import Flask, current_app
from models.account import Account
@@ -93,6 +95,9 @@ class ApplicationManager:
            extras=extras
        )

        if not stream and application_generate_entity.app_orchestration_config_entity.agent:
            raise ValueError("Agent app is not supported in blocking mode.")

        # init generate records
        (
            conversation,
@@ -151,7 +156,7 @@ class ApplicationManager:

            if application_generate_entity.app_orchestration_config_entity.agent:
                # agent app
                runner = AgentApplicationRunner()
                runner = AssistantApplicationRunner()
                runner.run(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
@@ -354,6 +359,8 @@ class ApplicationManager:

        # external data variables
        properties['external_data_variables'] = []

        # old external_data_tools
        external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
        for external_data_tool in external_data_tools:
            if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
@@ -366,6 +373,19 @@ class ApplicationManager:
                    config=external_data_tool['config']
                )
            )

        # current external_data_tools
        for variable in copy_app_model_config_dict.get('user_input_form', []):
            typ = list(variable.keys())[0]
            if typ == 'external_data_tool':
                val = variable[typ]
                properties['external_data_variables'].append(
                    ExternalDataVariableEntity(
                        variable=val['variable'],
                        type=val['type'],
                        config=val['config']
                    )
                )

        # show retrieve source
        show_retrieve_source = False
@@ -375,15 +395,65 @@ class ApplicationManager:
                show_retrieve_source = True

        properties['show_retrieve_source'] = show_retrieve_source

        dataset_ids = []
        if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}):
            datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', {
                'strategy': 'router',
                'datasets': []
            })

            for dataset in datasets.get('datasets', []):
                keys = list(dataset.keys())
                if len(keys) == 0 or keys[0] != 'dataset':
                    continue
                dataset = dataset['dataset']

                if 'enabled' not in dataset or not dataset['enabled']:
                    continue

                dataset_id = dataset.get('id', None)
                if dataset_id:
                    dataset_ids.append(dataset_id)
        else:
            datasets = {'strategy': 'router', 'datasets': []}

        if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
                and 'enabled' in copy_app_model_config_dict['agent_mode'] and copy_app_model_config_dict['agent_mode'][
            'enabled']:
            agent_dict = copy_app_model_config_dict.get('agent_mode')
            agent_strategy = agent_dict.get('strategy', 'router')
            if agent_strategy in ['router', 'react_router']:
                dataset_ids = []
                for tool in agent_dict.get('tools', []):
                and 'enabled' in copy_app_model_config_dict['agent_mode'] \
                and copy_app_model_config_dict['agent_mode']['enabled']:

            agent_dict = copy_app_model_config_dict.get('agent_mode', {})
            agent_strategy = agent_dict.get('strategy', 'cot')

            if agent_strategy == 'function_call':
                strategy = AgentEntity.Strategy.FUNCTION_CALLING
            elif agent_strategy == 'cot' or agent_strategy == 'react':
                strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
            else:
                # old configs, try to detect default strategy
                if copy_app_model_config_dict['model']['provider'] == 'openai':
                    strategy = AgentEntity.Strategy.FUNCTION_CALLING
                else:
                    strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT

            agent_tools = []
            for tool in agent_dict.get('tools', []):
                keys = tool.keys()
                if len(keys) >= 4:
                    if "enabled" not in tool or not tool["enabled"]:
                        continue

                    agent_tool_properties = {
                        'provider_type': tool['provider_type'],
                        'provider_id': tool['provider_id'],
                        'tool_name': tool['tool_name'],
                        'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {}
                    }

                    agent_tools.append(AgentToolEntity(**agent_tool_properties))
                elif len(keys) == 1:
                    # old standard
                    key = list(tool.keys())[0]

                    if key != 'dataset':
@@ -396,59 +466,60 @@ class ApplicationManager:

            dataset_id = tool_item['id']
            dataset_ids.append(dataset_id)

        dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
        query_variable = copy_app_model_config_dict.get('dataset_query_variable')
        if dataset_configs['retrieval_model'] == 'single':
            properties['dataset'] = DatasetEntity(
                dataset_ids=dataset_ids,
                retrieve_config=DatasetRetrieveConfigEntity(
                    query_variable=query_variable,
                    retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                        dataset_configs['retrieval_model']
                    ),
                    single_strategy=agent_strategy
                )

        if 'strategy' in copy_app_model_config_dict['agent_mode'] and \
                copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']:
            agent_prompt = agent_dict.get('prompt', None) or {}
            # check model mode
            model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion')
            if model_mode == 'completion':
                agent_prompt_entity = AgentPromptEntity(
                    first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
                    next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']),
                )
            else:
                properties['dataset'] = DatasetEntity(
                    dataset_ids=dataset_ids,
                    retrieve_config=DatasetRetrieveConfigEntity(
                        query_variable=query_variable,
                        retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                            dataset_configs['retrieval_model']
                        ),
                        top_k=dataset_configs.get('top_k'),
                        score_threshold=dataset_configs.get('score_threshold'),
                        reranking_model=dataset_configs.get('reranking_model')
|
||||
)
|
||||
agent_prompt_entity = AgentPromptEntity(
|
||||
first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
|
||||
next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
|
||||
)
|
||||
else:
|
||||
if agent_strategy == 'react':
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
else:
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
|
||||
agent_tools = []
|
||||
for tool in agent_dict.get('tools', []):
|
||||
key = list(tool.keys())[0]
|
||||
tool_item = tool[key]
|
||||
|
||||
agent_tool_properties = {
|
||||
"tool_id": key
|
||||
}
|
||||
|
||||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||
continue
|
||||
|
||||
agent_tool_properties["config"] = tool_item
|
||||
agent_tools.append(AgentToolEntity(**agent_tool_properties))
|
||||
|
||||
properties['agent'] = AgentEntity(
|
||||
provider=properties['model_config'].provider,
|
||||
model=properties['model_config'].model,
|
||||
strategy=strategy,
|
||||
tools=agent_tools
|
||||
prompt=agent_prompt_entity,
|
||||
tools=agent_tools,
|
||||
max_iteration=agent_dict.get('max_iteration', 5)
|
||||
)
|
||||
|
||||
if len(dataset_ids) > 0:
|
||||
# dataset configs
|
||||
dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
|
||||
query_variable = copy_app_model_config_dict.get('dataset_query_variable')
|
||||
|
||||
if dataset_configs['retrieval_model'] == 'single':
|
||||
properties['dataset'] = DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable=query_variable,
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs['retrieval_model']
|
||||
),
|
||||
single_strategy=datasets.get('strategy', 'router')
|
||||
)
|
||||
)
|
||||
else:
|
||||
properties['dataset'] = DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable=query_variable,
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs['retrieval_model']
|
||||
),
|
||||
top_k=dataset_configs.get('top_k'),
|
||||
score_threshold=dataset_configs.get('score_threshold'),
|
||||
reranking_model=dataset_configs.get('reranking_model')
|
||||
)
|
||||
)
|
||||
|
||||
# file upload
|
||||
@@ -485,6 +556,12 @@ class ApplicationManager:
|
||||
if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
|
||||
properties['speech_to_text'] = True
|
||||
|
||||
# text to speech
|
||||
text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
|
||||
if text_to_speech_dict:
|
||||
if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
|
||||
properties['text_to_speech'] = True
|
||||
|
||||
# sensitive word avoidance
|
||||
sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
|
||||
if sensitive_word_avoidance_dict:
|
||||
@@ -601,6 +678,7 @@ class ApplicationManager:
|
||||
message_id=message.id,
|
||||
type=file.type.value,
|
||||
transfer_method=file.transfer_method.value,
|
||||
belongs_to='user',
|
||||
url=file.url,
|
||||
upload_file_id=file.upload_file_id,
|
||||
created_by_role=('account' if account_id else 'end_user'),
|
||||
|
||||
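The strategy resolution introduced above can be read as the following condensed sketch; the helper function is hypothetical, but the config keys and enum members are the ones the diff uses:

def detect_strategy(agent_mode: dict, model: dict) -> AgentEntity.Strategy:
    # 'function_call' and 'cot'/'react' are the explicit config values
    agent_strategy = agent_mode.get('strategy', 'cot')
    if agent_strategy == 'function_call':
        return AgentEntity.Strategy.FUNCTION_CALLING
    if agent_strategy in ('cot', 'react'):
        return AgentEntity.Strategy.CHAIN_OF_THOUGHT
    # legacy configs carry no usable strategy: OpenAI models default to function calling
    if model.get('provider') == 'openai':
        return AgentEntity.Strategy.FUNCTION_CALLING
    return AgentEntity.Strategy.CHAIN_OF_THOUGHT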
@@ -4,13 +4,13 @@ from enum import Enum
from typing import Any, Generator

from core.entities.application_entities import InvokeFrom
from core.entities.queue_entities import (AnnotationReplyEvent, AppQueueEvent, QueueAgentThoughtEvent, QueueErrorEvent,
                                          QueueMessage, QueueMessageEndEvent, QueueMessageEvent,
                                          QueueMessageReplaceEvent, QueuePingEvent, QueueRetrieverResourcesEvent,
                                          QueueStopEvent)
from core.entities.queue_entities import (AnnotationReplyEvent, AppQueueEvent, QueueAgentMessageEvent,
                                          QueueAgentThoughtEvent, QueueErrorEvent, QueueMessage, QueueMessageEndEvent,
                                          QueueMessageEvent, QueueMessageFileEvent, QueueMessageReplaceEvent,
                                          QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent)
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from extensions.ext_redis import redis_client
from models.model import MessageAgentThought
from models.model import MessageAgentThought, MessageFile
from sqlalchemy.orm import DeclarativeMeta


@@ -96,6 +96,18 @@ class ApplicationQueueManager:
            chunk=chunk
        ), pub_from)

    def publish_agent_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
        """
        Publish agent chunk message to channel

        :param chunk: chunk
        :param pub_from: publish from
        :return:
        """
        self.publish(QueueAgentMessageEvent(
            chunk=chunk
        ), pub_from)

    def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
        """
        Publish message replace
@@ -144,6 +156,17 @@ class ApplicationQueueManager:
            agent_thought_id=message_agent_thought.id
        ), pub_from)

    def publish_message_file(self, message_file: MessageFile, pub_from: PublishFrom) -> None:
        """
        Publish message file
        :param message_file: message file
        :param pub_from: publish from
        :return:
        """
        self.publish(QueueMessageFileEvent(
            message_file_id=message_file.id
        ), pub_from)

    def publish_error(self, e, pub_from: PublishFrom) -> None:
        """
        Publish error

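A minimal usage sketch for the two new publishers, assuming a runner that already holds a queue_manager, an LLMResultChunk and a MessageFile; the call shapes match the signatures above:

queue_manager.publish_agent_chunk_message(chunk, PublishFrom.APPLICATION_MANAGER)
queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)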
api/core/callback_handler/agent_tool_callback_handler.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import os
from typing import Any, Dict, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from pydantic import BaseModel


class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
    """Callback Handler that prints to std out."""
    color: Optional[str] = ''
    current_loop = 1

    def __init__(self, color: Optional[str] = None) -> None:
        super().__init__()
        """Initialize callback handler."""
        # use a specific color if not specified
        self.color = color or 'green'
        self.current_loop = 1

    def on_tool_start(
        self,
        tool_name: str,
        tool_inputs: Dict[str, Any],
    ) -> None:
        """Do nothing."""
        print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)

    def on_tool_end(
        self,
        tool_name: str,
        tool_inputs: Dict[str, Any],
        tool_outputs: str,
    ) -> None:
        """If not the final action, print out observation."""
        print_text("\n[on_tool_end]\n", color=self.color)
        print_text("Tool: " + tool_name + "\n", color=self.color)
        print_text("Inputs: " + str(tool_inputs) + "\n", color=self.color)
        print_text("Outputs: " + str(tool_outputs) + "\n", color=self.color)
        print_text("\n")

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red')

    def on_agent_start(
        self, thought: str
    ) -> None:
        """Run on agent start."""
        if thought:
            print_text("\n[on_agent_start] \nCurrent Loop: " + \
                       str(self.current_loop) + \
                       "\nThought: " + thought + "\n", color=self.color)
        else:
            print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)

    def on_agent_finish(
        self, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Run on agent end."""
        print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)

        self.current_loop += 1

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'

    @property
    def ignore_chat_model(self) -> bool:
        """Whether to ignore chat model callbacks."""
        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'

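Since ignore_agent and ignore_chat_model return True unless DEBUG=true, the handler is effectively silent outside debug runs. A hypothetical smoke test:

import os

os.environ['DEBUG'] = 'true'                      # enable the printing callbacks
handler = DifyAgentCallbackHandler()              # defaults to green output
handler.on_tool_start('google_search', {'query': 'dify'})
handler.on_tool_end('google_search', {'query': 'dify'}, 'top results')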
@@ -27,7 +27,7 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM

class FileExtractor:
    @classmethod
    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document] | str]:
    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
@@ -36,7 +36,7 @@ class FileExtractor:
            return cls.load_from_file(file_path, return_text, upload_file, is_automatic)

    @classmethod
    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]:
        response = requests.get(url, headers={
            "User-Agent": USER_AGENT
        })
@@ -52,7 +52,7 @@ class FileExtractor:
    @classmethod
    def load_from_file(cls, file_path: str, return_text: bool = False,
                       upload_file: Optional[UploadFile] = None,
                       is_automatic: bool = False) -> Union[List[Document] | str]:
                       is_automatic: bool = False) -> Union[List[Document], str]:
        input_file = Path(file_path)
        delimiter = '\n'
        file_extension = input_file.suffix.lower()
@@ -68,7 +68,7 @@ class FileExtractor:
            else MarkdownLoader(file_path, autodetect_encoding=True)
        elif file_extension in ['.htm', '.html']:
            loader = HTMLLoader(file_path)
        elif file_extension == '.docx':
        elif file_extension in ['.docx', '.doc']:
            loader = Docx2txtLoader(file_path)
        elif file_extension == '.csv':
            loader = CSVLoader(file_path, autodetect_encoding=True)
@@ -95,7 +95,7 @@ class FileExtractor:
            loader = MarkdownLoader(file_path, autodetect_encoding=True)
        elif file_extension in ['.htm', '.html']:
            loader = HTMLLoader(file_path)
        elif file_extension == '.docx':
        elif file_extension in ['.docx', '.doc']:
            loader = Docx2txtLoader(file_path)
        elif file_extension == '.csv':
            loader = CSVLoader(file_path, autodetect_encoding=True)

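All three signature changes in this file fix the same typing slip: Union[List[Document] | str] subscripts Union with a single, already-united argument. On Python 3.10+ that happens to evaluate to the intended union, but on older interpreters the | operator is not defined for typing generics and the annotation raises TypeError. The two-argument spelling is the conventional, portable form:

from typing import List, Union

class Document: ...                      # stand-in for langchain.schema.Document

# Union[List[Document] | str]           # unidiomatic; TypeError before Python 3.10
LoadResult = Union[List[Document], str]  # equivalent and portable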
@@ -1,9 +1,7 @@
import logging
import re
from typing import List, Optional, Tuple, cast
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)

@@ -1,14 +1,11 @@
import logging
import re
from typing import List, Optional, Tuple, cast
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredPPTLoader(BaseLoader):
    """Load msg files.


@@ -1,14 +1,10 @@
import logging
import re
from typing import List, Optional, Tuple, cast
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredPPTXLoader(BaseLoader):
    """Load msg files.


@@ -1,9 +1,7 @@
import logging
import re
from typing import List, Optional, Tuple, cast
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)

@@ -1,9 +1,7 @@
import logging
import re
from typing import List, Optional, Tuple, cast
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)

@@ -8,9 +8,8 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from extensions.ext_database import db
from langchain.embeddings.base import Embeddings

from extensions.ext_redis import redis_client
from langchain.embeddings.base import Embeddings
from libs import helper
from models.dataset import Embedding
from sqlalchemy.exc import IntegrityError

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Optional, cast
from typing import Any, Literal, Optional, Union, cast

from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileObj
@@ -153,9 +153,35 @@ class AgentToolEntity(BaseModel):
    """
    Agent Tool Entity.
    """
    tool_id: str
    config: dict[str, Any] = {}
    provider_type: Literal["builtin", "api"]
    provider_id: str
    tool_name: str
    tool_parameters: dict[str, Any] = {}

class AgentPromptEntity(BaseModel):
    """
    Agent Prompt Entity.
    """
    first_prompt: str
    next_iteration: str

class AgentScratchpadUnit(BaseModel):
    """
    Agent First Prompt Entity.
    """

    class Action(BaseModel):
        """
        Action Entity.
        """
        action_name: str
        action_input: Union[dict, str]

    agent_response: Optional[str] = None
    thought: Optional[str] = None
    action_str: Optional[str] = None
    observation: Optional[str] = None
    action: Optional[Action] = None

class AgentEntity(BaseModel):
    """
@@ -171,8 +197,9 @@ class AgentEntity(BaseModel):
    provider: str
    model: str
    strategy: Strategy
    tools: list[AgentToolEntity] = []

    prompt: Optional[AgentPromptEntity] = None
    tools: list[AgentToolEntity] = None
    max_iteration: int = 5

class AppOrchestrationConfigEntity(BaseModel):
    """
@@ -191,6 +218,7 @@ class AppOrchestrationConfigEntity(BaseModel):
    show_retrieve_source: bool = False
    more_like_this: bool = False
    speech_to_text: bool = False
    text_to_speech: bool = False
    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None


@@ -255,7 +283,6 @@ class ApplicationGenerateEntity(BaseModel):
    query: Optional[str] = None
    files: list[FileObj] = []
    user_id: str

    # extras
    stream: bool
    invoke_from: InvokeFrom

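A hypothetical construction of the reworked entities, using only fields this hunk introduces (the provider and tool names are invented for illustration):

tool = AgentToolEntity(
    provider_type='builtin',
    provider_id='google',
    tool_name='google_search',
    tool_parameters={'query': 'dify'},
)
prompt = AgentPromptEntity(first_prompt='Answer the question.',
                           next_iteration='Continue reasoning.')
unit = AgentScratchpadUnit(
    thought='I should search first',
    action=AgentScratchpadUnit.Action(action_name='google_search',
                                      action_input={'query': 'dify'}),
)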
@@ -153,8 +153,16 @@ class ProviderConfiguration(BaseModel):

        if provider_record:
            try:
                original_credentials = json.loads(
                    provider_record.encrypted_config) if provider_record.encrypted_config else {}
                # fix origin data
                if provider_record.encrypted_config:
                    if not provider_record.encrypted_config.startswith("{"):
                        original_credentials = {
                            "openai_api_key": provider_record.encrypted_config
                        }
                    else:
                        original_credentials = json.loads(provider_record.encrypted_config)
                else:
                    original_credentials = {}
            except JSONDecodeError:
                original_credentials = {}

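The fallback added here migrates legacy provider rows that stored a bare API key instead of a JSON object; condensed, the logic is (a sketch, not the repo's actual helper):

import json
from typing import Optional

def parse_encrypted_config(raw: Optional[str]) -> dict:
    if not raw:
        return {}
    if not raw.startswith('{'):
        return {'openai_api_key': raw}   # legacy plain-string credential
    return json.loads(raw)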
@@ -9,6 +9,7 @@ from pydantic import BaseModel
class QuotaUnit(Enum):
    TIMES = 'times'
    TOKENS = 'tokens'
    CREDITS = 'credits'


class SystemConfigurationStatus(Enum):

@@ -10,11 +10,13 @@ class QueueEvent(Enum):
    QueueEvent enum
    """
    MESSAGE = "message"
    AGENT_MESSAGE = "agent_message"
    MESSAGE_REPLACE = "message-replace"
    MESSAGE_END = "message-end"
    RETRIEVER_RESOURCES = "retriever-resources"
    ANNOTATION_REPLY = "annotation-reply"
    AGENT_THOUGHT = "agent-thought"
    MESSAGE_FILE = "message-file"
    ERROR = "error"
    PING = "ping"
    STOP = "stop"
@@ -33,7 +35,14 @@ class QueueMessageEvent(AppQueueEvent):
    """
    event = QueueEvent.MESSAGE
    chunk: LLMResultChunk


class QueueAgentMessageEvent(AppQueueEvent):
    """
    QueueAgentMessageEvent entity
    """
    event = QueueEvent.AGENT_MESSAGE
    chunk: LLMResultChunk


class QueueMessageReplaceEvent(AppQueueEvent):
    """
@@ -73,7 +82,13 @@ class QueueAgentThoughtEvent(AppQueueEvent):
    """
    event = QueueEvent.AGENT_THOUGHT
    agent_thought_id: str


class QueueMessageFileEvent(AppQueueEvent):
    """
    QueueMessageFileEvent entity
    """
    event = QueueEvent.MESSAGE_FILE
    message_file_id: str

class QueueErrorEvent(AppQueueEvent):
    """

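On the consumer side, the two new events would be handled alongside the existing ones; a hypothetical dispatch (render_agent_chunk and attach_file are invented names):

def handle(event: AppQueueEvent) -> None:
    if isinstance(event, QueueAgentMessageEvent):
        render_agent_chunk(event.chunk)        # streamed agent tokens
    elif isinstance(event, QueueMessageFileEvent):
        attach_file(event.message_file_id)     # file produced by a tool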
@@ -13,11 +13,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIInput, OptimizedSerpAPIWrapper
from core.tool.web_reader_tool import WebReaderTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from langchain import WikipediaAPIWrapper
from langchain.callbacks.base import BaseCallbackHandler
@@ -132,55 +128,6 @@ class AgentRunnerFeature:
            logger.exception("agent_executor run failed")
            return None

    def to_tools(self, tool_configs: list[AgentToolEntity],
                 invoke_from: InvokeFrom,
                 callbacks: list[BaseCallbackHandler]) \
            -> Optional[List[BaseTool]]:
        """
        Convert tool configs to tools
        :param tool_configs: tool configs
        :param invoke_from: invoke from
        :param callbacks: callbacks
        """
        tools = []
        for tool_config in tool_configs:
            tool = None
            if tool_config.tool_id == "dataset":
                tool = self.to_dataset_retriever_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )
            elif tool_config.tool_id == "web_reader":
                tool = self.to_web_reader_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )
            elif tool_config.tool_id == "google_search":
                tool = self.to_google_search_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )
            elif tool_config.tool_id == "wikipedia":
                tool = self.to_wikipedia_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )
            elif tool_config.tool_id == "current_datetime":
                tool = self.to_current_datetime_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )

            if tool:
                if tool.callbacks is not None:
                    tool.callbacks.extend(callbacks)
                else:
                    tool.callbacks = callbacks

                tools.append(tool)

        return tools

    def to_dataset_retriever_tool(self, tool_config: dict,
                                  invoke_from: InvokeFrom) \
            -> Optional[BaseTool]:
@@ -247,78 +194,4 @@ class AgentRunnerFeature:
            retriever_from=invoke_from.to_source()
        )

        return tool

    def to_web_reader_tool(self, tool_config: dict,
                           invoke_from: InvokeFrom) -> Optional[BaseTool]:
        """
        A tool for reading web pages
        :param tool_config: tool config
        :param invoke_from: invoke from
        :return:
        """
        model_parameters = {
            "temperature": 0,
            "max_tokens": 500
        }

        tool = WebReaderTool(
            model_config=self.model_config,
            model_parameters=model_parameters,
            max_chunk_length=4000,
            continue_reading=True
        )

        return tool

    def to_google_search_tool(self, tool_config: dict,
                              invoke_from: InvokeFrom) -> Optional[BaseTool]:
        """
        A tool for performing a Google search and extracting snippets and webpages
        :param tool_config: tool config
        :param invoke_from: invoke from
        :return:
        """
        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
        func_kwargs = tool_provider.credentials_to_func_kwargs()
        if not func_kwargs:
            return None

        tool = Tool(
            name="google_search",
            description="A tool for performing a Google search and extracting snippets and webpages "
                        "when you need to search for something you don't know or when your information "
                        "is not up to date. "
                        "Input should be a search query.",
            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
            args_schema=OptimizedSerpAPIInput
        )

        return tool

    def to_current_datetime_tool(self, tool_config: dict,
                                 invoke_from: InvokeFrom) -> Optional[BaseTool]:
        """
        A tool for getting the current date and time
        :param tool_config: tool config
        :param invoke_from: invoke from
        :return:
        """
        return DatetimeTool()

    def to_wikipedia_tool(self, tool_config: dict,
                          invoke_from: InvokeFrom) -> Optional[BaseTool]:
        """
        A tool for searching Wikipedia
        :param tool_config: tool config
        :param invoke_from: invoke from
        :return:
        """
        class WikipediaInput(BaseModel):
            query: str = Field(..., description="search query.")

        return WikipediaQueryRun(
            name="wikipedia",
            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
            args_schema=WikipediaInput
        )
        return tool
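For context, the deleted factories only recognized five hard-coded tool_id values; under the new entities a tool is addressed by its provider instead. Roughly (the provider_id here is invented):

old_tool_config = {'google_search': {'enabled': True}}   # single-key legacy form
new_tool_config = {                                      # provider-addressed form
    'enabled': True,
    'provider_type': 'builtin',                          # or 'api'
    'provider_id': 'google',
    'tool_name': 'google_search',
    'tool_parameters': {},
}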
api/core/features/assistant_base_runner.py (new file, 582 lines)
@@ -0,0 +1,582 @@
import json
import logging
from datetime import datetime
from mimetypes import guess_extension
from typing import List, Optional, Tuple, Union, cast

from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import (AgentEntity, AgentToolEntity, ApplicationGenerateEntity,
                                                AppOrchestrationConfigEntity, InvokeFrom, ModelConfigEntity)
from core.file.message_file_parser import FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.tool_entities import (ToolInvokeMessage, ToolInvokeMessageBinary, ToolParameter,
                                               ToolRuntimeVariablePool)
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from models.model import Message, MessageAgentThought, MessageFile
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)

class BaseAssistantApplicationRunner(AppRunner):
    def __init__(self, tenant_id: str,
                 application_generate_entity: ApplicationGenerateEntity,
                 app_orchestration_config: AppOrchestrationConfigEntity,
                 model_config: ModelConfigEntity,
                 config: AgentEntity,
                 queue_manager: ApplicationQueueManager,
                 message: Message,
                 user_id: str,
                 memory: Optional[TokenBufferMemory] = None,
                 prompt_messages: Optional[List[PromptMessage]] = None,
                 variables_pool: Optional[ToolRuntimeVariablePool] = None,
                 db_variables: Optional[ToolConversationVariables] = None,
                 model_instance: ModelInstance = None
                 ) -> None:
        """
        Agent runner
        :param tenant_id: tenant id
        :param app_orchestration_config: app orchestration config
        :param model_config: model config
        :param config: dataset config
        :param queue_manager: queue manager
        :param message: message
        :param user_id: user id
        :param agent_llm_callback: agent llm callback
        :param callback: callback
        :param memory: memory
        """
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.app_orchestration_config = app_orchestration_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = prompt_messages
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.application_generate_entity.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_orchestration_config.dataset.dataset_ids if app_orchestration_config.dataset else [],
            retrieve_config=app_orchestration_config.dataset.retrieve_config if app_orchestration_config.dataset else None,
            return_resource=app_orchestration_config.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback
        )
        # get how many agent thoughts have been created
        self.agent_thought_count = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.message_id == self.message.id,
        ).count()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
            self.stream_tool_call = True
        else:
            self.stream_tool_call = False

    def _repack_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
        """
        Repack app orchestration config
        """
        if app_orchestration_config.prompt_template.simple_prompt_template is None:
            app_orchestration_config.prompt_template.simple_prompt_template = ''

        return app_orchestration_config

    def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
        """
        Handle tool response
        """
        result = ''
        for response in tool_response:
            if response.type == ToolInvokeMessage.MessageType.TEXT:
                result += response.message
            elif response.type == ToolInvokeMessage.MessageType.LINK:
                result += f"result link: {response.message}. please tell user to check it."
            elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
                    response.type == ToolInvokeMessage.MessageType.IMAGE:
                result += f"image has been created and sent to user already, you should tell user to check it now."
            else:
                result += f"tool response: {response.message}."

        return result

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
        """
        convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_tool_runtime(
            provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
            tenant_id=self.application_generate_entity.tenant_id,
            agent_callback=self.agent_callback
        )
        tool_entity.load_variables(self.variables_pool)

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        )

        runtime_parameters = {}

        parameters = tool_entity.parameters or []
        user_parameters = tool_entity.get_runtime_parameters() or []

        # override parameters
        for parameter in user_parameters:
            # check if parameter in tool parameters
            found = False
            for tool_parameter in parameters:
                if tool_parameter.name == parameter.name:
                    found = True
                    break

            if found:
                # override parameter
                tool_parameter.type = parameter.type
                tool_parameter.form = parameter.form
                tool_parameter.required = parameter.required
                tool_parameter.default = parameter.default
                tool_parameter.options = parameter.options
                tool_parameter.llm_description = parameter.llm_description
            else:
                # add new parameter
                parameters.append(parameter)

        for parameter in parameters:
            parameter_type = 'string'
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.STRING:
                parameter_type = 'string'
            elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
                parameter_type = 'boolean'
            elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
                parameter_type = 'number'
            elif parameter.type == ToolParameter.ToolParameterType.SELECT:
                for option in parameter.options:
                    enum.append(option.value)
                parameter_type = 'string'
            else:
                raise ValueError(f"parameter type {parameter.type} is not supported")

            if parameter.form == ToolParameter.ToolParameterForm.FORM:
                # get tool parameter from form
                tool_parameter_config = tool.tool_parameters.get(parameter.name)
                if not tool_parameter_config:
                    # get default value
                    tool_parameter_config = parameter.default
                    if not tool_parameter_config and parameter.required:
                        raise ValueError(f"tool parameter {parameter.name} not found in tool config")

                if parameter.type == ToolParameter.ToolParameterType.SELECT:
                    # check if tool_parameter_config in options
                    options = list(map(lambda x: x.value, parameter.options))
                    if tool_parameter_config not in options:
                        raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")

                # convert tool parameter config to correct type
                try:
                    if parameter.type == ToolParameter.ToolParameterType.NUMBER:
                        # check if tool parameter is integer
                        if isinstance(tool_parameter_config, int):
                            tool_parameter_config = tool_parameter_config
                        elif isinstance(tool_parameter_config, float):
                            tool_parameter_config = tool_parameter_config
                        elif isinstance(tool_parameter_config, str):
                            if '.' in tool_parameter_config:
                                tool_parameter_config = float(tool_parameter_config)
                            else:
                                tool_parameter_config = int(tool_parameter_config)
                    elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
                        tool_parameter_config = bool(tool_parameter_config)
                    elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
                        tool_parameter_config = str(tool_parameter_config)
                    elif parameter.type == ToolParameter.ToolParameterType:
                        tool_parameter_config = str(tool_parameter_config)
                except Exception as e:
                    raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")

                # save tool parameter to tool entity memory
                runtime_parameters[parameter.name] = tool_parameter_config

            elif parameter.form == ToolParameter.ToolParameterForm.LLM:
                message_tool.parameters['properties'][parameter.name] = {
                    "type": parameter_type,
                    "description": parameter.llm_description or '',
                }

                if len(enum) > 0:
                    message_tool.parameters['properties'][parameter.name]['enum'] = enum

                if parameter.required:
                    message_tool.parameters['required'].append(parameter.name)

        tool_entity.runtime.runtime_parameters.update(runtime_parameters)

        return message_tool, tool_entity

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        convert dataset retriever tool to prompt message tool
        """
        prompt_tool = PromptMessageTool(
            name=tool.identity.name,
            description=tool.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = 'string'

            prompt_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters['required']:
                    prompt_tool.parameters['required'].append(parameter.name)

        return prompt_tool

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        update prompt message tool
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters() or []

        for parameter in tool_runtime_parameters:
            parameter_type = 'string'
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.STRING:
                parameter_type = 'string'
            elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
                parameter_type = 'boolean'
            elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
                parameter_type = 'number'
            elif parameter.type == ToolParameter.ToolParameterType.SELECT:
                for option in parameter.options:
                    enum.append(option.value)
                parameter_type = 'string'
            else:
                raise ValueError(f"parameter type {parameter.type} is not supported")

            if parameter.form == ToolParameter.ToolParameterForm.LLM:
                prompt_tool.parameters['properties'][parameter.name] = {
                    "type": parameter_type,
                    "description": parameter.llm_description or '',
                }

                if len(enum) > 0:
                    prompt_tool.parameters['properties'][parameter.name]['enum'] = enum

                if parameter.required:
                    if parameter.name not in prompt_tool.parameters['required']:
                        prompt_tool.parameters['required'].append(parameter.name)

        return prompt_tool

    def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
        """
        Extract tool response binary
        """
        result = []

        for response in tool_response:
            if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
                    response.type == ToolInvokeMessage.MessageType.IMAGE:
                result.append(ToolInvokeMessageBinary(
                    mimetype=response.meta.get('mime_type', 'octet/stream'),
                    url=response.message,
                    save_as=response.save_as,
                ))
            elif response.type == ToolInvokeMessage.MessageType.BLOB:
                result.append(ToolInvokeMessageBinary(
                    mimetype=response.meta.get('mime_type', 'octet/stream'),
                    url=response.message,
                    save_as=response.save_as,
                ))
            elif response.type == ToolInvokeMessage.MessageType.LINK:
                # check if there is a mime type in meta
                if response.meta and 'mime_type' in response.meta:
                    result.append(ToolInvokeMessageBinary(
                        mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
                        url=response.message,
                        save_as=response.save_as,
                    ))

        return result

    def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
        """
        Create message file

        :param messages: messages
        :return: message files, should save as variable
        """
        result = []

        for message in messages:
            file_type = 'bin'
            if 'image' in message.mimetype:
                file_type = 'image'
            elif 'video' in message.mimetype:
                file_type = 'video'
            elif 'audio' in message.mimetype:
                file_type = 'audio'
            elif 'text' in message.mimetype:
                file_type = 'text'
            elif 'pdf' in message.mimetype:
                file_type = 'pdf'
            elif 'zip' in message.mimetype:
                file_type = 'archive'
            # ...

            invoke_from = self.application_generate_entity.invoke_from

            message_file = MessageFile(
                message_id=self.message.id,
                type=file_type,
                transfer_method=FileTransferMethod.TOOL_FILE.value,
                belongs_to='assistant',
                url=message.url,
                upload_file_id=None,
                created_by_role=('account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
                created_by=self.user_id,
            )
            db.session.add(message_file)
            result.append((
                message_file,
                message.save_as
            ))

        db.session.commit()

        return result

    def create_agent_thought(self, message_id: str, message: str,
                             tool_name: str, tool_input: str, messages_ids: List[str]
                             ) -> MessageAgentThought:
        """
        Create agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought='',
            tool=tool_name,
            tool_labels_str='{}',
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else '',
            answer='',
            observation='',
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency='USD',
            latency=0,
            created_by_role='account',
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(self,
                           agent_thought: MessageAgentThought,
                           tool_name: str,
                           tool_input: Union[str, dict],
                           thought: str,
                           observation: str,
                           answer: str,
                           messages_ids: List[str],
                           llm_usage: LLMUsage = None) -> MessageAgentThought:
        """
        Save agent thought
        """
        if thought is not None:
            agent_thought.thought = thought

        if tool_name is not None:
            agent_thought.tool = tool_name

        if tool_input is not None:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception as e:
                    tool_input = json.dumps(tool_input)

            agent_thought.tool_input = tool_input

        if observation is not None:
            agent_thought.observation = observation

        if answer is not None:
            agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            agent_thought.message_token = llm_usage.prompt_tokens
            agent_thought.message_price_unit = llm_usage.prompt_price_unit
            agent_thought.message_unit_price = llm_usage.prompt_unit_price
            agent_thought.answer_token = llm_usage.completion_tokens
            agent_thought.answer_price_unit = llm_usage.completion_price_unit
            agent_thought.answer_unit_price = llm_usage.completion_unit_price
            agent_thought.tokens = llm_usage.total_tokens
            agent_thought.total_price = llm_usage.total_price

        # check if tool labels is not empty
        labels = agent_thought.tool_labels or {}
        tools = agent_thought.tool.split(';') if agent_thought.tool else []
        for tool in tools:
            if not tool:
                continue
            if tool not in labels:
                tool_label = ToolManager.get_tool_label(tool)
                if tool_label:
                    labels[tool] = tool_label.to_dict()
                else:
                    labels[tool] = {'en_US': tool, 'zh_Hans': tool}

        agent_thought.tool_labels_str = json.dumps(labels)

        db.session.commit()

    def get_history_prompt_messages(self) -> List[PromptMessage]:
        """
        Get history prompt messages
        """
        if self.history_prompt_messages is None:
            self.history_prompt_messages = db.session.query(PromptMessage).filter(
                PromptMessage.message_id == self.message.id,
            ).order_by(PromptMessage.position.asc()).all()

        return self.history_prompt_messages

    def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]:
        """
        Transform tool message into agent thought
        """
        result = []

        for message in messages:
            if message.type == ToolInvokeMessage.MessageType.TEXT:
                result.append(message)
            elif message.type == ToolInvokeMessage.MessageType.LINK:
                result.append(message)
            elif message.type == ToolInvokeMessage.MessageType.IMAGE:
                # try to download image
                try:
                    file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id,
                                                              conversation_id=self.message.conversation_id,
                                                              file_url=message.message)

                    url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'

                    result.append(ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                        message=url,
                        save_as=message.save_as,
                        meta=message.meta.copy() if message.meta is not None else {},
                    ))
                except Exception as e:
                    logger.exception(e)
                    result.append(ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.TEXT,
                        message=f"Failed to download image: {message.message}, you can try to download it yourself.",
                        meta=message.meta.copy() if message.meta is not None else {},
                        save_as=message.save_as,
                    ))
            elif message.type == ToolInvokeMessage.MessageType.BLOB:
                # get mime type and save blob to storage
                mimetype = message.meta.get('mime_type', 'octet/stream')
                # if message is str, encode it to bytes
                if isinstance(message.message, str):
                    message.message = message.message.encode('utf-8')
                file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id,
                                                          conversation_id=self.message.conversation_id,
                                                          file_binary=message.message,
                                                          mimetype=mimetype)

                url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'

                # check if file is image
                if 'image' in mimetype:
                    result.append(ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                        message=url,
                        save_as=message.save_as,
                        meta=message.meta.copy() if message.meta is not None else {},
                    ))
                else:
                    result.append(ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.LINK,
                        message=url,
                        save_as=message.save_as,
                        meta=message.meta.copy() if message.meta is not None else {},
                    ))
            else:
                result.append(message)

        return result

    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
        """
        convert tool variables to db variables
        """
        db_variables.updated_at = datetime.utcnow()
        db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
        db.session.commit()
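The schema that _convert_tool_to_prompt_message_tool assembles follows the JSON-Schema style used for function calling; for a single required string parameter it would look roughly like this (all values invented for illustration):

message_tool = PromptMessageTool(
    name='google_search',
    description='Search the web',
    parameters={
        'type': 'object',
        'properties': {
            'query': {'type': 'string', 'description': 'search query'},
        },
        'required': ['query'],
    },
)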
api/core/features/assistant_cot_runner.py (new file, 579 lines)
@@ -0,0 +1,579 @@
import json
import logging
import re
from typing import Dict, Generator, List, Literal, Union

from core.application_queue_manager import PublishFrom
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
                                                          SystemPromptMessage, UserPromptMessage)
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.errors import (ToolInvokeError, ToolNotFoundError, ToolNotSupportedError, ToolParameterValidationError,
                               ToolProviderCredentialValidationError, ToolProviderNotFoundError)
from models.model import Conversation, Message


class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
    def run(self, conversation: Conversation,
            message: Message,
            query: str,
            ) -> Union[Generator, LLMResult]:
        """
        Run Cot agent application
        """
        app_orchestration_config = self.app_orchestration_config
        self._repack_app_orchestration_config(app_orchestration_config)

        agent_scratchpad: List[AgentScratchpadUnit] = []

        # check model mode
        if self.app_orchestration_config.model_config.mode == "completion":
            # TODO: stop words
            if 'Observation' not in app_orchestration_config.model_config.stop:
                app_orchestration_config.model_config.stop.append('Observation')

        iteration_step = 1
        max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1

        prompt_messages = self.history_prompt_messages

        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: List[PromptMessageTool] = []
        tool_instances = {}
        for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        function_call_state = True
        llm_usage = {
            'usage': None
        }
        final_answer = ''

        def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                llm_usage = final_llm_usage_dict['usage']
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            # continue to run until there is not any tool call
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []

            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            if iteration_step > 1:
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # update prompt messages
            prompt_messages = self._organize_cot_prompt_messages(
                mode=app_orchestration_config.model_config.mode,
                prompt_messages=prompt_messages,
                tools=prompt_messages_tools,
                agent_scratchpad=agent_scratchpad,
                agent_prompt_message=app_orchestration_config.agent.prompt,
                instruction=app_orchestration_config.prompt_template.simple_prompt_template,
                input=query
            )

            # recale llm max tokens
            self.recale_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            llm_result: LLMResult = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                tools=[],
                stop=app_orchestration_config.model_config.stop,
                stream=False,
                user=self.user_id,
                callbacks=[],
            )

            # check llm result
            if not llm_result:
                raise ValueError("failed to invoke llm")

            # get scratchpad
            scratchpad = self._extract_response_scratchpad(llm_result.message.content)
            agent_scratchpad.append(scratchpad)

            # get llm usage
            if llm_result.usage:
                increase_usage(llm_usage, llm_result.usage)

            # publish agent thought if it's first iteration
            if iteration_step == 1:
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            self.save_agent_thought(agent_thought=agent_thought,
                                    tool_name=scratchpad.action.action_name if scratchpad.action else '',
                                    tool_input=scratchpad.action.action_input if scratchpad.action else '',
                                    thought=scratchpad.thought,
                                    observation='',
                                    answer=llm_result.message.content,
                                    messages_ids=[],
                                    llm_usage=llm_result.usage)

            if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # publish agent thought if it's not empty and there is an action
            if scratchpad.thought and scratchpad.action:
                # check if final answer
                if not scratchpad.action.action_name.lower() == "final answer":
                    yield LLMResultChunk(
                        model=model_instance.model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=0,
                            message=AssistantPromptMessage(
                                content=scratchpad.thought
                            ),
                            usage=llm_result.usage,
                        ),
                        system_fingerprint=''
                    )

            if not scratchpad.action:
                # failed to extract action, return final answer directly
                final_answer = scratchpad.agent_response or ''
            else:
                if scratchpad.action.action_name.lower() == "final answer":
                    # action is final answer, return final answer directly
                    try:
                        final_answer = scratchpad.action.action_input if \
                            isinstance(scratchpad.action.action_input, str) else \
                            json.dumps(scratchpad.action.action_input)
                    except json.JSONDecodeError:
                        final_answer = f'{scratchpad.action.action_input}'
                else:
                    function_call_state = True

                    # action is tool call, invoke tool
                    tool_call_name = scratchpad.action.action_name
                    tool_call_args = scratchpad.action.action_input
                    tool_instance = tool_instances.get(tool_call_name)
                    if not tool_instance:
                        answer = f"there is not a tool named {tool_call_name}"
                        self.save_agent_thought(agent_thought=agent_thought,
                                                tool_name='',
                                                tool_input='',
                                                thought=None,
                                                observation=answer,
                                                answer=answer,
                                                messages_ids=[])
                        self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
                    else:
                        # invoke tool
                        error_response = None
                        try:
                            tool_response = tool_instance.invoke(
                                user_id=self.user_id,
                                tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
                            )
                            # transform tool response to llm friendly response
                            tool_response = self.transform_tool_invoke_messages(tool_response)
                            # extract binary data from tool invoke message
                            binary_files = self.extract_tool_response_binary(tool_response)
                            # create message file
                            message_files = self.create_message_files(binary_files)
                            # publish files
                            for message_file, save_as in message_files:
                                if save_as:
                                    self.variables_pool.set_file(tool_name=tool_call_name,
                                                                 value=message_file.id,
                                                                 name=save_as)
                                self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)

                            message_file_ids = [message_file.id for message_file, _ in message_files]
                        except ToolProviderCredentialValidationError as e:
                            error_response = f"Please check your tool provider credentials"
                        except (
                            ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
                        ) as e:
                            error_response = f"there is not a tool named {tool_call_name}"
                        except (
                            ToolParameterValidationError
                        ) as e:
                            error_response = f"tool parameters validation error: {e}, please check your tool parameters"
                        except ToolInvokeError as e:
                            error_response = f"tool invoke error: {e}"
                        except Exception as e:
                            error_response = f"unknown error: {e}"

                        if error_response:
                            observation = error_response
                        else:
                            observation = self._convert_tool_response_to_str(tool_response)

                        # save scratchpad
                        scratchpad.observation = observation
                        scratchpad.agent_response = llm_result.message.content

                        # save agent thought
                        self.save_agent_thought(
                            agent_thought=agent_thought,
                            tool_name=tool_call_name,
                            tool_input=tool_call_args,
                            thought=None,
                            observation=observation,
                            answer=llm_result.message.content,
                            messages_ids=message_file_ids,
                        )
                        self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
# update prompt tool message
|
||||
for prompt_tool in prompt_messages_tools:
|
||||
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer
|
||||
),
|
||||
usage=llm_usage['usage']
|
||||
),
|
||||
system_fingerprint=''
|
||||
)
|
||||
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name='',
|
||||
tool_input='',
|
||||
thought=final_answer,
|
||||
observation='',
|
||||
answer=final_answer,
|
||||
messages_ids=[]
|
||||
)
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish_message_end(LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer
|
||||
),
|
||||
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
|
||||
system_fingerprint=''
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
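An editorial aside on the `increase_usage` helper above: it mutates a one-key dict rather than assigning to a plain variable, because a nested function cannot rebind its enclosing function's local without `nonlocal`, while mutating a shared container works everywhere. A minimal, self-contained sketch of the pattern (the `Usage` class is a hypothetical stand-in for `LLMUsage`, reduced to the fields the runner touches):

from dataclasses import dataclass

@dataclass
class Usage:  # stand-in for LLMUsage; hypothetical, not the real class
    prompt_tokens: int = 0
    completion_tokens: int = 0

def increase_usage(total: dict, usage: Usage):
    # first chunk seeds the dict; later chunks accumulate in place
    if not total['usage']:
        total['usage'] = usage
    else:
        total['usage'].prompt_tokens += usage.prompt_tokens
        total['usage'].completion_tokens += usage.completion_tokens

total = {'usage': None}
increase_usage(total, Usage(10, 5))
increase_usage(total, Usage(7, 3))
assert (total['usage'].prompt_tokens, total['usage'].completion_tokens) == (17, 8)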
    def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
        """
        extract response from llm response
        """
        def extra_quotes() -> AgentScratchpadUnit:
            agent_response = content
            # try to extract all fenced quotes
            pattern = re.compile(r'```(.*?)```', re.DOTALL)
            quotes = pattern.findall(content)

            # try to extract an action, scanning from end to start
            for i in range(len(quotes) - 1, 0, -1):
                """
                1. use json load to parse action
                2. use plain text `Action: xxx` to parse action
                """
                try:
                    action = json.loads(quotes[i].replace('```', ''))
                    action_name = action.get("action")
                    action_input = action.get("action_input")
                    agent_thought = agent_response.replace(quotes[i], '')

                    if action_name and action_input:
                        return AgentScratchpadUnit(
                            agent_response=content,
                            thought=agent_thought,
                            action_str=quotes[i],
                            action=AgentScratchpadUnit.Action(
                                action_name=action_name,
                                action_input=action_input,
                            )
                        )
                except Exception:
                    # try to parse action from plain text
                    action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
                    action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
                    # delete action from agent response
                    agent_thought = agent_response.replace(quotes[i], '')
                    # remove extra quotes
                    agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
                    # remove Action: xxx from agent thought
                    agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)

                    if action_name and action_input:
                        return AgentScratchpadUnit(
                            agent_response=content,
                            thought=agent_thought,
                            action_str=quotes[i],
                            action=AgentScratchpadUnit.Action(
                                action_name=action_name[0],
                                action_input=action_input[0],
                            )
                        )

        def extra_json():
            agent_response = content
            # try to extract all top-level json structures
            structures, pair_match_stack = [], []
            started_at, end_at = 0, 0
            for i in range(len(content)):
                if content[i] == '{':
                    pair_match_stack.append(i)
                    if len(pair_match_stack) == 1:
                        started_at = i
                elif content[i] == '}':
                    begin = pair_match_stack.pop()
                    if not pair_match_stack:
                        end_at = i + 1
                        structures.append((content[begin:i+1], (started_at, end_at)))

            # handle an unterminated trailing structure
            if pair_match_stack:
                end_at = len(content)
                structures.append((content[pair_match_stack[0]:], (started_at, end_at)))

            for i in range(len(structures), 0, -1):
                try:
                    json_content, (started_at, end_at) = structures[i - 1]
                    action = json.loads(json_content)
                    action_name = action.get("action")
                    action_input = action.get("action_input")
                    # delete json content from agent response
                    agent_thought = agent_response[:started_at] + agent_response[end_at:]
                    # remove extra quotes like ```(json)*\n\n```
                    agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
                    # remove Action: xxx from agent thought
                    agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)

                    if action_name and action_input is not None:
                        return AgentScratchpadUnit(
                            agent_response=content,
                            thought=agent_thought,
                            action_str=json_content,
                            action=AgentScratchpadUnit.Action(
                                action_name=action_name,
                                action_input=action_input,
                            )
                        )
                except Exception:
                    pass

        agent_scratchpad = extra_quotes()
        if agent_scratchpad:
            return agent_scratchpad
        agent_scratchpad = extra_json()
        if agent_scratchpad:
            return agent_scratchpad

        return AgentScratchpadUnit(
            agent_response=content,
            thought=content,
            action_str='',
            action=None
        )

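To make the fallback concrete, here is a minimal, runnable sketch of the brace-pairing idea that `extra_json` uses, run on an invented response (the sample text is hypothetical, not taken from the diff):

import json

# Hypothetical LLM output in the shape this extractor targets.
content = ('Thought: I should search first.\n'
           '{"action": "web_search", "action_input": {"query": "dify agent"}}')

# Same idea as extra_json(): pair braces on a stack, keep top-level {...} spans.
stack, spans = [], []
for i, ch in enumerate(content):
    if ch == '{':
        stack.append(i)
    elif ch == '}' and stack:
        begin = stack.pop()
        if not stack:  # stack emptied: a complete top-level structure
            spans.append(content[begin:i + 1])

action = json.loads(spans[-1])
assert action["action"] == "web_search"
assert action["action_input"] == {"query": "dify agent"}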
    def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                   agent_prompt_message: AgentPromptEntity,
                                   ):
        """
        check chain of thought prompt messages, a standard prompt message is like:
            Respond to the human as helpfully and accurately as possible.

            {{instruction}}

            You have access to the following tools:

            {{tools}}

            Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
            Valid action values: "Final Answer" or {{tool_names}}

            Provide only ONE action per $JSON_BLOB, as shown:

            ```
            {
              "action": $TOOL_NAME,
              "action_input": $ACTION_INPUT
            }
            ```
        """

        # parse agent prompt message
        first_prompt = agent_prompt_message.first_prompt
        next_iteration = agent_prompt_message.next_iteration

        if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
            raise ValueError("first_prompt and next_iteration are required in CoT agent mode")

        # check instruction, tools, and tool_names slots
        if not first_prompt.find("{{instruction}}") >= 0:
            raise ValueError("{{instruction}} is required in first_prompt")
        if not first_prompt.find("{{tools}}") >= 0:
            raise ValueError("{{tools}} is required in first_prompt")
        if not first_prompt.find("{{tool_names}}") >= 0:
            raise ValueError("{{tool_names}} is required in first_prompt")

        if mode == "completion":
            if not first_prompt.find("{{query}}") >= 0:
                raise ValueError("{{query}} is required in first_prompt")
            if not first_prompt.find("{{agent_scratchpad}}") >= 0:
                raise ValueError("{{agent_scratchpad}} is required in first_prompt")

        if mode == "completion":
            if not next_iteration.find("{{observation}}") >= 0:
                raise ValueError("{{observation}} is required in next_iteration")

    def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
        """
        convert agent scratchpad list to str
        """
        next_iteration = self.app_orchestration_config.agent.prompt.next_iteration

        result = ''
        for scratchpad in agent_scratchpad:
            result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"

        return result

    def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                      prompt_messages: List[PromptMessage],
                                      tools: List[PromptMessageTool],
                                      agent_scratchpad: List[AgentScratchpadUnit],
                                      agent_prompt_message: AgentPromptEntity,
                                      instruction: str,
                                      input: str,
                                      ) -> List[PromptMessage]:
        """
        organize chain of thought prompt messages, a standard prompt message is like:
            Respond to the human as helpfully and accurately as possible.

            {{instruction}}

            You have access to the following tools:

            {{tools}}

            Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
            Valid action values: "Final Answer" or {{tool_names}}

            Provide only ONE action per $JSON_BLOB, as shown:

            ```
            {{{{
              "action": $TOOL_NAME,
              "action_input": $ACTION_INPUT
            }}}}
            ```
        """

        self._check_cot_prompt_messages(mode, agent_prompt_message)

        # parse agent prompt message
        first_prompt = agent_prompt_message.first_prompt

        # parse tools
        tools_str = self._jsonify_tool_prompt_messages(tools)

        # parse tool names
        tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'

        # get system message
        system_message = first_prompt.replace("{{instruction}}", instruction) \
            .replace("{{tools}}", tools_str) \
            .replace("{{tool_names}}", tool_names)

        # organize prompt messages
        if mode == "chat":
            # override system message
            overridden = False
            prompt_messages = prompt_messages.copy()
            for prompt_message in prompt_messages:
                if isinstance(prompt_message, SystemPromptMessage):
                    prompt_message.content = system_message
                    overridden = True
                    break

            if not overridden:
                prompt_messages.insert(0, SystemPromptMessage(
                    content=system_message,
                ))

            # add assistant message
            if len(agent_scratchpad) > 0:
                prompt_messages.append(AssistantPromptMessage(
                    content=(agent_scratchpad[-1].thought or '')
                ))

            # add user message
            if len(agent_scratchpad) > 0:
                prompt_messages.append(UserPromptMessage(
                    content=(agent_scratchpad[-1].observation or ''),
                ))

            return prompt_messages
        elif mode == "completion":
            # parse agent scratchpad
            agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
            # parse prompt messages
            return [UserPromptMessage(
                content=first_prompt.replace("{{instruction}}", instruction)
                    .replace("{{tools}}", tools_str)
                    .replace("{{tool_names}}", tool_names)
                    .replace("{{query}}", input)
                    .replace("{{agent_scratchpad}}", agent_scratchpad_str),
            )]
        else:
            raise ValueError(f"mode {mode} is not supported")

    def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
        """
        jsonify tool prompt messages
        """
        tools = jsonable_encoder(tools)
        try:
            return json.dumps(tools, ensure_ascii=False)
        except json.JSONDecodeError:
            return json.dumps(tools)

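For orientation, a small sketch of what the completion branch of `_organize_cot_prompt_messages` produces: plain string substitution into `first_prompt`. The template and values below are invented for illustration, not taken from the repository:

first_prompt = (
    "{{instruction}}\n"
    "Tools: {{tools}}\n"
    "Valid actions: \"Final Answer\" or {{tool_names}}\n"
    "Question: {{query}}\n"
    "{{agent_scratchpad}}"
)

# the same chained .replace() calls the method uses
filled = first_prompt.replace("{{instruction}}", "Answer briefly.") \
    .replace("{{tools}}", '[{"name": "web_search"}]') \
    .replace("{{tool_names}}", '"web_search"') \
    .replace("{{query}}", "What is Dify?") \
    .replace("{{agent_scratchpad}}", "")
print(filled)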
415
api/core/features/assistant_fc_runner.py
Normal file
@@ -0,0 +1,415 @@
import json
import logging
from typing import Any, Dict, Generator, List, Tuple, Union

from core.application_queue_manager import PublishFrom
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
                                                          SystemPromptMessage, ToolPromptMessage, UserPromptMessage)
from core.tools.errors import (ToolInvokeError, ToolNotFoundError, ToolNotSupportedError, ToolParameterValidationError,
                               ToolProviderCredentialValidationError, ToolProviderNotFoundError)
from models.model import Conversation, Message, MessageAgentThought

logger = logging.getLogger(__name__)


class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
    def run(self, conversation: Conversation,
            message: Message,
            query: str,
            ) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        app_orchestration_config = self.app_orchestration_config

        prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or ''
        prompt_messages = self.history_prompt_messages
        prompt_messages = self.organize_prompt_messages(
            prompt_template=prompt_template,
            query=query,
            prompt_messages=prompt_messages
        )

        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: List[PromptMessageTool] = []
        tool_instances = {}
        for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        iteration_step = 1
        max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1

        # continue to run until there is no tool call left
        function_call_state = True
        agent_thoughts: List[MessageAgentThought] = []
        llm_usage = {
            'usage': None
        }
        final_answer = ''

        def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                llm_usage = final_llm_usage_dict['usage']
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            # recalculate llm max tokens
            self.recale_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                tools=prompt_messages_tools,
                stop=app_orchestration_config.model_config.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

            tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []

            # save full response
            response = ''

            # save tool call names and inputs
            tool_call_names = ''
            tool_call_inputs = ''

            current_llm_usage = None

            if self.stream_tool_call:
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
                        is_first_chunk = False
                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            }, ensure_ascii=False)
                        except TypeError:
                            # fall back to default (ascii-escaped) encoding
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            })

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result: LLMResult = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        }, ensure_ascii=False)
                    except TypeError:
                        # fall back to default (ascii-escaped) encoding
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        })

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ''

                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    )
                )

            if tool_calls:
                prompt_messages.append(AssistantPromptMessage(
                    content='',
                    name='',
                    tool_calls=[AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type='function',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1],
                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        )
                    ) for tool_call in tool_calls]
                ))

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage
            )

            self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            final_answer += response + '\n'

            # update prompt messages
            if response.strip():
                prompt_messages.append(AssistantPromptMessage(
                    content=response,
                ))

            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is not a tool named {tool_call_name}"
                    }
                    tool_responses.append(tool_response)
                else:
                    # invoke tool
                    error_response = None
                    try:
                        tool_invoke_message = tool_instance.invoke(
                            user_id=self.user_id,
                            tool_parameters=tool_call_args,
                        )
                        # transform tool invoke message to an LLM-friendly message
                        tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message)
                        # extract binary data from tool invoke message
                        binary_files = self.extract_tool_response_binary(tool_invoke_message)
                        # create message files
                        message_files = self.create_message_files(binary_files)
                        # publish files
                        for message_file, save_as in message_files:
                            if save_as:
                                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)

                            # publish message file
                            self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
                            # add message file ids
                            message_file_ids.append(message_file.id)

                    except ToolProviderCredentialValidationError as e:
                        error_response = "Please check your tool provider credentials"
                    except (
                        ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
                    ) as e:
                        error_response = f"there is not a tool named {tool_call_name}"
                    except (
                        ToolParameterValidationError
                    ) as e:
                        error_response = f"tool parameters validation error: {e}, please check your tool parameters"
                    except ToolInvokeError as e:
                        error_response = f"tool invoke error: {e}"
                    except Exception as e:
                        error_response = f"unknown error: {e}"

                    if error_response:
                        observation = error_response
                        tool_response = {
                            "tool_call_id": tool_call_id,
                            "tool_call_name": tool_call_name,
                            "tool_response": error_response
                        }
                        tool_responses.append(tool_response)
                    else:
                        observation = self._convert_tool_response_to_str(tool_invoke_message)
                        tool_response = {
                            "tool_call_id": tool_call_id,
                            "tool_call_name": tool_call_name,
                            "tool_response": observation
                        }
                        tool_responses.append(tool_response)

                prompt_messages = self.organize_prompt_messages(
                    prompt_template=prompt_template,
                    query=None,
                    tool_call_id=tool_call_id,
                    tool_call_name=tool_call_name,
                    tool_response=tool_response['tool_response'],
                    prompt_messages=prompt_messages,
                )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    observation=tool_response['tool_response'],
                    answer=None,
                    messages_ids=message_file_ids
                )
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish_message_end(LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer,
            ),
            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
            system_fingerprint=''
        ), PublishFrom.APPLICATION_MANAGER)

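A note on the conversation shape this loop builds: each iteration appends an assistant message carrying the model's tool_calls, then one tool message per call, so the next model invocation sees the full trace. Sketched with plain dicts in the OpenAI wire style that these entity classes mirror (field names are illustrative, not the entity classes themselves):

messages = [
    {"role": "system", "content": "You are a helpful agent."},
    {"role": "user", "content": "What's the weather in Paris?"},
    # iteration 1: the model asked for a tool
    {"role": "assistant", "content": "", "tool_calls": [
        {"id": "call_1", "type": "function",
         "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}},
    ]},
    # the runner appends the observation, keyed by tool_call_id
    {"role": "tool", "tool_call_id": "call_1", "name": "get_weather",
     "content": "18°C, clear"},
    # iteration 2: no tool call, so function_call_state stays False and the loop ends
]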
    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))

        return tool_calls

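Both extractors above normalize tool calls into `(id, name, args)` tuples, decoding the JSON-encoded `arguments` string along the way. A tiny runnable sketch of the same normalization on stubbed data (the stub shape is hypothetical, standing in for the tool_call entities):

import json
from types import SimpleNamespace

# stand-in for one entry of message.tool_calls
raw_call = SimpleNamespace(
    id="call_1",
    function=SimpleNamespace(name="get_weather", arguments='{"city": "Paris"}'),
)

# arguments arrive as a JSON string and are decoded into a dict
tool_call = (raw_call.id, raw_call.function.name, json.loads(raw_call.function.arguments))
assert tool_call == ("call_1", "get_weather", {"city": "Paris"})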
    def organize_prompt_messages(self, prompt_template: str,
                                 query: str = None,
                                 tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
                                 prompt_messages: list[PromptMessage] = None
                                 ) -> list[PromptMessage]:
        """
        Organize prompt messages
        """

        if not prompt_messages:
            prompt_messages = [
                SystemPromptMessage(content=prompt_template),
                UserPromptMessage(content=query),
            ]
        else:
            if tool_response:
                prompt_messages = prompt_messages.copy()
                prompt_messages.append(
                    ToolPromptMessage(
                        content=tool_response,
                        tool_call_id=tool_call_id,
                        name=tool_call_name,
                    )
                )

        return prompt_messages

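Sketching how `organize_prompt_messages` is used across a turn: the first call (no prior messages) seeds system plus user, and each later call appends one `ToolPromptMessage` per observation. A hypothetical trace, assuming `runner` is an already-constructed AssistantFunctionCallApplicationRunner:

# first call of the turn: seed the conversation
msgs = runner.organize_prompt_messages(
    prompt_template="You are a helpful agent.",
    query="What's the weather in Paris?",
)
# after a tool returns: append its observation
msgs = runner.organize_prompt_messages(
    prompt_template="You are a helpful agent.",
    tool_call_id="call_1",
    tool_call_name="get_weather",
    tool_response="18°C, clear",
    prompt_messages=msgs,
)
# msgs is now [SystemPromptMessage, UserPromptMessage, ToolPromptMessage]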
@@ -6,8 +6,8 @@ from core.entities.application_entities import DatasetEntity, DatasetRetrieveCon
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
-from core.tool.dataset_retriever_tool import DatasetRetrieverTool
+from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
+from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 from extensions.ext_database import db
 from langchain.tools import BaseTool
 from models.dataset import Dataset

@@ -22,6 +22,7 @@ class FileType(enum.Enum):
 class FileTransferMethod(enum.Enum):
     REMOTE_URL = 'remote_url'
     LOCAL_FILE = 'local_file'
+    TOOL_FILE = 'tool_file'
 
     @staticmethod
     def value_of(value):
@@ -30,6 +31,16 @@ class FileTransferMethod(enum.Enum):
                 return member
         raise ValueError(f"No matching enum found for value '{value}'")
 
+
+class FileBelongsTo(enum.Enum):
+    USER = 'user'
+    ASSISTANT = 'assistant'
+
+    @staticmethod
+    def value_of(value):
+        for member in FileBelongsTo:
+            if member.value == value:
+                return member
+        raise ValueError(f"No matching enum found for value '{value}'")
 
 class FileObj(BaseModel):
     id: Optional[str]

@@ -1,11 +1,11 @@
 from typing import Dict, List, Optional, Union
 
 import requests
-from core.file.file_obj import FileObj, FileTransferMethod, FileType
-from core.file.upload_file_parser import SUPPORT_EXTENSIONS
+from core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType
 from extensions.ext_database import db
 from models.account import Account
 from models.model import AppModelConfig, EndUser, MessageFile, UploadFile
+from services.file_service import IMAGE_EXTENSIONS
 
 
 class MessageFileParser:
@@ -83,7 +83,7 @@ class MessageFileParser:
                 UploadFile.tenant_id == self.tenant_id,
                 UploadFile.created_by == user.id,
                 UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
-                UploadFile.extension.in_(SUPPORT_EXTENSIONS)
+                UploadFile.extension.in_(IMAGE_EXTENSIONS)
             ).first())
 
             # check upload file is belong to tenant and user
@@ -128,6 +128,10 @@ class MessageFileParser:
 
         # group by file type and convert file args or message files to FileObj
        for file in files:
+            if isinstance(file, MessageFile):
+                if file.belongs_to == FileBelongsTo.ASSISTANT.value:
+                    continue
+
             file_obj = self._to_file_obj(file, file_upload_config)
             if file_obj.type not in type_file_objs:
                 continue

8
api/core/file/tool_file_parser.py
Normal file
@@ -0,0 +1,8 @@
tool_file_manager = {
    'manager': None
}


class ToolFileParser:
    @staticmethod
    def get_tool_file_manager() -> 'ToolFileManager':
        return tool_file_manager['manager']
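The new `tool_file_parser.py` is a tiny late-binding registry: the manager lives in a module-level dict and is filled in at startup, so low-level modules can fetch the `ToolFileManager` without importing it directly, which sidesteps a circular import. A minimal runnable sketch of the same pattern, with a hypothetical registrant:

# registry module (mirrors tool_file_parser.py)
registry = {'manager': None}

class Parser:
    @staticmethod
    def get_manager():
        return registry['manager']

# at application startup, a higher-level module injects the instance
class FakeManager:  # hypothetical stand-in for ToolFileManager
    def sign(self, file_id: str) -> str:
        return f"/files/{file_id}?signed=1"

registry['manager'] = FakeManager()
assert Parser.get_manager().sign("abc") == "/files/abc?signed=1"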
@@ -9,8 +9,8 @@ from typing import Optional
 from extensions.ext_storage import storage
 from flask import current_app
 
-SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
+IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
+IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
 
 class UploadFileParser:
     @classmethod
@@ -18,7 +18,7 @@ class UploadFileParser:
         if not upload_file:
             return None
 
-        if upload_file.extension not in SUPPORT_EXTENSIONS:
+        if upload_file.extension not in IMAGE_EXTENSIONS:
             return None
 
         if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url:

@@ -2,7 +2,7 @@ from typing import Optional
 
 from core.entities.provider_entities import QuotaUnit, RestrictModel
 from core.model_runtime.entities.model_entities import ModelType
-from flask import Flask, Config
+from flask import Config, Flask
 from models.provider import ProviderQuotaType
 from pydantic import BaseModel

@@ -20,10 +20,6 @@ class TrialHostingQuota(HostingQuota):
 
 class PaidHostingQuota(HostingQuota):
     quota_type: ProviderQuotaType = ProviderQuotaType.PAID
-    stripe_price_id: str = None
-    increase_quota: int = 1
-    min_quantity: int = 20
-    max_quantity: int = 100
 
 
 class FreeHostingQuota(HostingQuota):
@@ -102,7 +98,7 @@ class HostingConfiguration:
         )
 
     def init_openai(self, app_config: Config) -> HostingProvider:
-        quota_unit = QuotaUnit.TIMES
+        quota_unit = QuotaUnit.CREDITS
        quotas = []
 
         if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
@@ -114,6 +110,8 @@ class HostingConfiguration:
                 RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
                 RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
                 RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
+                RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
+                RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
                 RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
                 RestrictModel(model="whisper-1", model_type=ModelType.SPEECH2TEXT),
             ]
@@ -122,10 +120,20 @@ class HostingConfiguration:
 
         if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
             paid_quota = PaidHostingQuota(
-                stripe_price_id=app_config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
-                increase_quota=int(app_config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA", "1")),
-                min_quantity=int(app_config.get("HOSTED_OPENAI_PAID_MIN_QUANTITY", "1")),
-                max_quantity=int(app_config.get("HOSTED_OPENAI_PAID_MAX_QUANTITY", "1"))
+                restrict_models=[
+                    RestrictModel(model="gpt-4", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-4-32k", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
+                    RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
+                ]
             )
             quotas.append(paid_quota)
@@ -164,12 +172,7 @@ class HostingConfiguration:
             quotas.append(trial_quota)
 
         if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"):
-            paid_quota = PaidHostingQuota(
-                stripe_price_id=app_config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
-                increase_quota=int(app_config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA", "1000000")),
-                min_quantity=int(app_config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY", "20")),
-                max_quantity=int(app_config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY", "100"))
-            )
+            paid_quota = PaidHostingQuota()
             quotas.append(paid_quota)
 
         if len(quotas) > 0:

@@ -13,7 +13,7 @@ from core.docstore.dataset_docstore import DatasetDocumentStore
 from core.errors.error import ProviderTokenNotInitError
 from core.generator.llm_generator import LLMGenerator
 from core.index.index import IndexBuilder
-from core.model_manager import ModelManager, ModelInstance
+from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType, PriceType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
@@ -562,7 +562,7 @@ class IndexingRunner:
 
         character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
             chunk_size=segmentation["max_tokens"],
-            chunk_overlap=0,
+            chunk_overlap=segmentation.get('chunk_overlap', 0),
             fixed_separator=separator,
             separators=["\n\n", "。", ".", " ", ""],
             embedding_model_instance=embedding_model_instance
@@ -571,7 +571,7 @@ class IndexingRunner:
         # Automatic segmentation
         character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
             chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
-            chunk_overlap=0,
+            chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
             separators=["\n\n", "。", ".", " ", ""],
             embedding_model_instance=embedding_model_instance
         )
@@ -655,7 +655,9 @@ class IndexingRunner:
                     else:
                         page_content = page_content
                     document_node.page_content = page_content
-                    split_documents.append(document_node)
+
+                    if document_node.page_content:
+                        split_documents.append(document_node)
             all_documents.extend(split_documents)
             # processing qa document
             if document_form == 'qa_model':

@@ -13,6 +13,7 @@ from core.model_runtime.model_providers.__base.moderation_model import Moderatio
 from core.model_runtime.model_providers.__base.rerank_model import RerankModel
 from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.__base.tts_model import TTSModel
 from core.provider_manager import ProviderManager
@@ -144,7 +145,7 @@ class ModelInstance:
             user=user
         )
 
-    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None, **params) \
+    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
             -> str:
         """
         Invoke large language model
@@ -161,8 +162,29 @@ class ModelInstance:
             model=self.model,
             credentials=self.credentials,
             file=file,
-            user=user,
-            **params
+            user=user
         )
 
+    def invoke_tts(self, content_text: str, streaming: bool, user: Optional[str] = None) \
+            -> str:
+        """
+        Invoke large language model
+
+        :param content_text: text content to be translated
+        :param user: unique user id
+        :param streaming: output is streaming
+        :return: text for given audio file
+        """
+        if not isinstance(self.model_type_instance, TTSModel):
+            raise Exception("Model type instance is not TTSModel")
+
+        self.model_type_instance = cast(TTSModel, self.model_type_instance)
+        return self.model_type_instance.invoke(
+            model=self.model,
+            credentials=self.credentials,
+            content_text=content_text,
+            user=user,
+            streaming=streaming
+        )
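Based only on the signature added above, a caller would use the new TTS entry point roughly like this (the `model_instance` setup is assumed, and the return handling follows the annotated `-> str`):

# hypothetical call site; model_instance is an already-configured ModelInstance
audio = model_instance.invoke_tts(
    content_text="Hello from Dify.",
    streaming=False,
    user="user-123",
)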
@@ -13,6 +13,7 @@ This module provides the interface for invoking and authenticating various model
 - `Text Embedding Model` - Text Embedding, pre-computed tokens capability
 - `Rerank Model` - Segment Rerank capability
 - `Speech-to-text Model` - Speech to text capability
+- `Text-to-speech Model` - Text to speech capability
 - `Moderation` - Moderation capability
 
 - Model provider display

@@ -13,6 +13,7 @@
 - `Text Embedding Model` - text embedding, with pre-computed token capability
 - `Rerank Model` - segment rerank capability
 - `Speech-to-text Model` - speech-to-text capability
+- `Text-to-speech Model` - text-to-speech capability
 - `Moderation` - moderation capability
 
 - Model provider display

@@ -299,9 +299,7 @@ Inherit the `__base.speech2text_model.Speech2TextModel` base class and implement
 - Invoke Invocation
 
   ```python
-  def _invoke(self, model: str, credentials: dict,
-              file: IO[bytes], user: Optional[str] = None) \
-          -> str:
+  def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
       """
       Invoke large language model
@@ -331,6 +329,46 @@ Inherit the `__base.speech2text_model.Speech2TextModel` base class and implement
 
   The string after speech-to-text conversion.
 
+### Text2speech
+
+Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement the following interfaces:
+
+- Invoke Invocation
+
+  ```python
+  def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
+      """
+      Invoke large language model
+
+      :param model: model name
+      :param credentials: model credentials
+      :param content_text: text content to be translated
+      :param streaming: output is streaming
+      :param user: unique user id
+      :return: translated audio file
+      """
+  ```
+
+- Parameters:
+
+  - `model` (string) Model name
+
+  - `credentials` (object) Credential information
+
+    The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
+
+  - `content_text` (string) The text content that needs to be converted
+
+  - `streaming` (bool) Whether to stream output
+
+  - `user` (string) [optional] Unique identifier of the user
+
+    This can help the provider monitor and detect abusive behavior.
+
+- Returns:
+
+  The converted speech stream.
+
 ### Moderation
 
 Inherit the `__base.moderation_model.ModerationModel` base class and implement the following interfaces:
Some files were not shown because too many files have changed in this diff.