mirror of
https://github.com/langgenius/dify.git
synced 2026-01-02 04:27:16 +00:00
Compare commits
216 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4678845dd | ||
|
|
174ebb51db | ||
|
|
626c78a690 | ||
|
|
9eaae770a6 | ||
|
|
ca60610306 | ||
|
|
082f8b17ab | ||
|
|
cf93d8d6e2 | ||
|
|
aae2fb8a30 | ||
|
|
23e52f14e3 | ||
|
|
c5b68fb273 | ||
|
|
6f17c9b2fe | ||
|
|
c98311b325 | ||
|
|
d44d4bd6fd | ||
|
|
2adaceab82 | ||
|
|
d979955c8a | ||
|
|
eae670ea4a | ||
|
|
b5825142d1 | ||
|
|
741e9303d4 | ||
|
|
538e3fc256 | ||
|
|
ba3dc8cae0 | ||
|
|
ae7c0380dc | ||
|
|
23e3413655 | ||
|
|
4fdb37771a | ||
|
|
94b54b7ca9 | ||
|
|
f9412f5fdb | ||
|
|
1d6829f400 | ||
|
|
f8bae897e5 | ||
|
|
dd1172b57e | ||
|
|
67d326a558 | ||
|
|
fe747040bc | ||
|
|
7d6c925cbc | ||
|
|
f488d06b20 | ||
|
|
c00a19ced3 | ||
|
|
e9810a6df2 | ||
|
|
cae15013e0 | ||
|
|
52c84da051 | ||
|
|
026f0bfce9 | ||
|
|
d19181fb29 | ||
|
|
2f9de2229f | ||
|
|
34f55739e0 | ||
|
|
668b059c07 | ||
|
|
753e5f1500 | ||
|
|
a6af8e5d8f | ||
|
|
3e1d5ac51b | ||
|
|
b0091452ca | ||
|
|
eff115267f | ||
|
|
07cde4f8fe | ||
|
|
9f28a48a92 | ||
|
|
0d3cd3b16a | ||
|
|
3dc82fb044 | ||
|
|
cb6e73347e | ||
|
|
ecd6cbaee6 | ||
|
|
d54e942264 | ||
|
|
28ba721455 | ||
|
|
784dd7848e | ||
|
|
e2a5f8ba1a | ||
|
|
8e11200306 | ||
|
|
7599f79a17 | ||
|
|
510389909c | ||
|
|
2c6e00174b | ||
|
|
24f3456990 | ||
|
|
20514ff288 | ||
|
|
381d255290 | ||
|
|
7f320f9146 | ||
|
|
cd51d3323b | ||
|
|
004b3caa43 | ||
|
|
dbe10799e3 | ||
|
|
054ba88434 | ||
|
|
da82a11b26 | ||
|
|
fec607db81 | ||
|
|
397a92f2ee | ||
|
|
b91e226063 | ||
|
|
da5782df92 | ||
|
|
9af0da4450 | ||
|
|
d49ac1e4ac | ||
|
|
57de19a5ca | ||
|
|
7c00a0b6a3 | ||
|
|
a93506df18 | ||
|
|
a03a92e9db | ||
|
|
feebb5dd1f | ||
|
|
6eee7cb42c | ||
|
|
11baff6740 | ||
|
|
cde1797cc0 | ||
|
|
d143284d99 | ||
|
|
2b94545190 | ||
|
|
ed6648a41e | ||
|
|
5e2c3eeac3 | ||
|
|
b23d8a912b | ||
|
|
4f13f8fd0a | ||
|
|
561c9cabd5 | ||
|
|
39ea967b30 | ||
|
|
da04ff040b | ||
|
|
b9b0866a46 | ||
|
|
c6ab7eebd9 | ||
|
|
db4e6d81c5 | ||
|
|
df68a7c82b | ||
|
|
838825d747 | ||
|
|
a87f6f2837 | ||
|
|
9d98669e7d | ||
|
|
408fbb0c70 | ||
|
|
998f819b04 | ||
|
|
6194b82752 | ||
|
|
334f46d0b6 | ||
|
|
2eea114ac0 | ||
|
|
97e9ebd29a | ||
|
|
ec261aea54 | ||
|
|
accc5faae3 | ||
|
|
0462f09ecc | ||
|
|
1226d73159 | ||
|
|
c67ecff3fe | ||
|
|
d5b42c09ee | ||
|
|
835bf9fd8d | ||
|
|
c720f831af | ||
|
|
df5763be37 | ||
|
|
80eebc2414 | ||
|
|
17d196126c | ||
|
|
addf150a9e | ||
|
|
cad1532f7c | ||
|
|
951afcaaed | ||
|
|
3241e4015b | ||
|
|
1dee5de9b4 | ||
|
|
742bad93b5 | ||
|
|
bb3cc6bba6 | ||
|
|
23ef2262bd | ||
|
|
d637a147ee | ||
|
|
8a4d19d9ba | ||
|
|
bea382f0dc | ||
|
|
8b39e48957 | ||
|
|
5b4538f021 | ||
|
|
36dc05c4da | ||
|
|
54f3bbbf47 | ||
|
|
f797fab206 | ||
|
|
ce2996e7d4 | ||
|
|
82d07ed2a8 | ||
|
|
c39d8f954e | ||
|
|
226f28edcb | ||
|
|
402b0b81d2 | ||
|
|
b08c19d926 | ||
|
|
9253f72dea | ||
|
|
f350948bde | ||
|
|
eeb2c28526 | ||
|
|
673288d58e | ||
|
|
772d67fd65 | ||
|
|
7552a6be36 | ||
|
|
33200090e8 | ||
|
|
01a6c725fa | ||
|
|
f6e04389e4 | ||
|
|
e22814b291 | ||
|
|
a66ef7210b | ||
|
|
184afa69ff | ||
|
|
ab115b5f87 | ||
|
|
3bbc4ad3db | ||
|
|
87af414a52 | ||
|
|
72555d5df8 | ||
|
|
fff39a307a | ||
|
|
a11f36ca60 | ||
|
|
433f8cb57e | ||
|
|
cd136fb293 | ||
|
|
6a3ab36101 | ||
|
|
1af968e73a | ||
|
|
94646f29c3 | ||
|
|
e028a0595c | ||
|
|
b16a7b0b3b | ||
|
|
e083a7067b | ||
|
|
205459d54d | ||
|
|
3d14431b96 | ||
|
|
2ba0ee989a | ||
|
|
b055470147 | ||
|
|
5943385d42 | ||
|
|
0abd67288b | ||
|
|
bbe58327c8 | ||
|
|
299c51ebc4 | ||
|
|
3a7f58d2a6 | ||
|
|
6123bba96d | ||
|
|
d5ab3b5072 | ||
|
|
df26f82536 | ||
|
|
dbe0c43515 | ||
|
|
f4052fdbc7 | ||
|
|
b5ade19c75 | ||
|
|
040eacb8bd | ||
|
|
20899c44ff | ||
|
|
35a2beb195 | ||
|
|
2056093855 | ||
|
|
2bf48514bc | ||
|
|
c109b1a920 | ||
|
|
45499328b8 | ||
|
|
4c61aa399d | ||
|
|
3e380c082a | ||
|
|
53db5bab36 | ||
|
|
6483beb096 | ||
|
|
e61c84ca72 | ||
|
|
d70086b841 | ||
|
|
a3ee037d6d | ||
|
|
2de18a6490 | ||
|
|
4134e915ce | ||
|
|
a838ba7b46 | ||
|
|
5f38214a41 | ||
|
|
19b5cb1e10 | ||
|
|
2478c88e07 | ||
|
|
59e59c19b2 | ||
|
|
c67f626b66 | ||
|
|
f65a3ad1cc | ||
|
|
490858a4d5 | ||
|
|
44a1aa5e44 | ||
|
|
a616bf3129 | ||
|
|
f2f19484b8 | ||
|
|
f572b55237 | ||
|
|
554570dc22 | ||
|
|
5239b2c7ab | ||
|
|
ae94b067b3 | ||
|
|
5e772bd10b | ||
|
|
91bcbd0b26 | ||
|
|
54bb309d87 | ||
|
|
75f7a96025 | ||
|
|
ccd80653ff | ||
|
|
5ca88a4fd9 |
@@ -19,7 +19,7 @@ def check_file_for_chinese_comments(file_path):
|
||||
|
||||
def main():
|
||||
has_chinese = False
|
||||
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py']
|
||||
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']
|
||||
|
||||
for root, _, files in os.walk("."):
|
||||
for file in files:
|
||||
|
||||
30
.github/workflows/stale.yml
vendored
Normal file
30
.github/workflows/stale.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
|
||||
#
|
||||
# You can adjust the behavior by modifying this file.
|
||||
# For more information, see:
|
||||
# https://github.com/actions/stale
|
||||
name: Mark stale issues and pull requests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 3 * * *'
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v5
|
||||
with:
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 3
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
|
||||
stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
|
||||
stale-issue-label: 'no-issue-activity'
|
||||
stale-pr-label: 'no-pr-activity'
|
||||
any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement'
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -109,6 +109,7 @@ venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
.conda/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
@@ -130,6 +131,7 @@ dmypy.json
|
||||
.idea/'
|
||||
|
||||
.DS_Store
|
||||
web/.vscode/settings.json
|
||||
|
||||
# Intellij IDEA Files
|
||||
.idea/
|
||||
@@ -146,3 +148,5 @@ docker/volumes/weaviate/*
|
||||
sdks/python-client/build
|
||||
sdks/python-client/dist
|
||||
sdks/python-client/dify_client.egg-info
|
||||
|
||||
.vscode/
|
||||
@@ -54,3 +54,8 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r
|
||||
## Community channels
|
||||
|
||||
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/AhzKf7dNgk). We are here to help!
|
||||
|
||||
### i18n (Internationalization) Support
|
||||
|
||||
We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.
|
||||
Also check out the [Frontend i18n README]((web/i18n/README_EN.md)) for more information.
|
||||
@@ -51,3 +51,7 @@ git clone git@github.com:<github_username>/dify.git
|
||||
## 社区渠道
|
||||
|
||||
遇到困难了吗?有任何问题吗? 加入 [Discord Community Server](https://discord.gg/AhzKf7dNgk),我们将为您提供帮助。
|
||||
|
||||
### 多语言支持
|
||||
|
||||
需要参与贡献翻译内容,请参阅[前端多语言翻译 README](web/i18n/README_CN.md)。
|
||||
|
||||
36
LICENSE
36
LICENSE
@@ -1,26 +1,26 @@
|
||||
# Dify Open Source License
|
||||
|
||||
The Dify project uses a combination of the Apache License 2.0, MIT License, and an additional agreement to protect against direct competition with Dify Cloud services.
|
||||
The Dify project is licensed under the Apache License 2.0, with the following additional conditions:
|
||||
|
||||
As a contributor, you should agree that your contributed code:
|
||||
a. Might be subject to a more permissive open source license in the future.
|
||||
1. Dify is permitted to be used for commercialization, such as using Dify as a "backend-as-a-service" for your other applications, or delivering it to enterprises as an application development platform. However, when the following conditions are met, you must contact the producer to obtain a commercial license:
|
||||
|
||||
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify.AI source code to operate a multi-tenant SaaS service that is similar to the Dify.AI service edition.
|
||||
b. LOGO and copyright information: In the process of using Dify, you may not remove or modify the LOGO or copyright information in the Dify console.
|
||||
|
||||
Please contact business@dify.ai by email to inquire about licensing matters.
|
||||
|
||||
2. As a contributor, you should agree that your contributed code:
|
||||
|
||||
a. The producer can adjust the open-source agreement to be more strict or relaxed.
|
||||
b. Can be used for commercial purposes, such as Dify's cloud business.
|
||||
|
||||
The following components are open source under the MIT license, allowing you to build and develop applications based on them:
|
||||
- WebApp elements, e.g., web/app/components/share
|
||||
- Derived WebApp Template projects
|
||||
|
||||
The remaining parts of the project are open source under the Apache License 2.0.
|
||||
|
||||
With the Apache License 2.0, MIT License, and this supplementary agreement, anyone can freely use, modify, and distribute Dify, provided that:
|
||||
|
||||
- If you use Dify solely as a backend service for other applications, no authorization is needed for commercial or closed source purposes.
|
||||
- If you wish to use Dify for commercial and closed source SaaS services similar to Dify Cloud, please contact us for authorization.
|
||||
Apart from this, all other rights and restrictions follow the Apache License 2.0. If you need more detailed information, you can refer to the full version of Apache License 2.0.
|
||||
|
||||
The interactive design of this product is protected by appearance patent.
|
||||
|
||||
© 2023 LangGenius, Inc.
|
||||
|
||||
|
||||
----------
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -34,13 +34,3 @@ distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
----------
|
||||
The MIT License
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
57
README.md
57
README.md
@@ -2,14 +2,12 @@
|
||||
<p align="center">
|
||||
<a href="./README.md">English</a> |
|
||||
<a href="./README_CN.md">简体中文</a> |
|
||||
<a href="./README_JA.md">日本語</a>
|
||||
<a href="./README_JA.md">日本語</a> |
|
||||
<a href="./README_ES.md">Español</a>
|
||||
</p>
|
||||
|
||||
[Website](https://dify.ai) • [Docs](https://docs.dify.ai) • [Twitter](https://twitter.com/dify_ai) • [Discord](https://discord.gg/FngNHpbcY7)
|
||||
|
||||
Vote for us on Product Hunt ↓
|
||||
<a href="https://www.producthunt.com/posts/dify-ai"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?sanitize=true&post_id=dify-ai&theme=light" alt="Product Hunt Badge" width="250" height="54"></a>
|
||||
|
||||
**Dify** is an easy-to-use LLMOps platform designed to empower more people to create sustainable, AI-native applications. With visual orchestration for various application types, Dify offers out-of-the-box, ready-to-use applications that can also serve as Backend-as-a-Service APIs. Unify your development process with one API for plugins and datasets integration, and streamline your operations using a single interface for prompt engineering, visual analytics, and continuous improvement.
|
||||
|
||||
Applications created with Dify include:
|
||||
@@ -19,9 +17,15 @@ A single API encompassing plugin capabilities, context enhancement, and more, sa
|
||||
Visual data analysis, log review, and annotation for applications
|
||||
Dify is compatible with Langchain, meaning we'll gradually support multiple LLMs, currently supported:
|
||||
|
||||
- GPT 3 (text-davinci-003)
|
||||
- GPT 3.5 Turbo(ChatGPT)
|
||||
- GPT-4
|
||||
* **OpenAI** :GPT4、GPT3.5-turbo、GPT3.5-turbo-16k、text-davinci-003
|
||||
|
||||
* **Azure OpenAI**
|
||||
|
||||
* **Antropic**:Claude2、Claude-instant
|
||||
> We've got 1000 free trial credits available for all cloud service users to try out the Claude model.Visit [Dify.ai](https://dify.ai) and
|
||||
try it now.
|
||||
|
||||
* **hugging face Hub**:Coming soon.
|
||||
|
||||
## Use Cloud Services
|
||||
|
||||
@@ -42,11 +46,16 @@ The easiest way to start the Dify server is to run our [docker-compose.yml](dock
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker-compose up -d
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization installation process.
|
||||
|
||||
### Helm Chart
|
||||
|
||||
A big thanks to @BorisPolonsky for providing us with a [Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
|
||||
You can go to https://github.com/BorisPolonsky/dify-helm for deployment information.
|
||||
|
||||
### Configuration
|
||||
|
||||
If you need to customize the configuration, please refer to the comments in our [docker-compose.yml](docker/docker-compose.yaml) file and manually set the environment configuration. After making the changes, please run 'docker-compose up -d' again.
|
||||
@@ -85,6 +94,32 @@ A: English and Chinese are currently supported, and you can contribute language
|
||||
|
||||
[](https://star-history.com/#langgenius/dify&Date)
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome you to contribute to Dify to help make Dify better. We welcome contributions in various ways, submitting code, issues, new ideas, or sharing the interesting and useful AI applications you have created based on Dify. At the same time, we also welcome you to share Dify at different events, conferences, and social media.
|
||||
|
||||
### Submit a Pull Request
|
||||
|
||||
To ensure proper review, all code contributions, including from contributors with direct commit access, must be submitted as PR requests and approved by core developers before merging branches.
|
||||
We welcome PRs from everyone! If you're willing to help out, you can learn more about how to contribute code to the project in the [Contribution Guide](CONTRIBUTING.md).
|
||||
|
||||
### Submit issues or ideas
|
||||
|
||||
You can submit your issues or ideas by adding issues to the Dify repository. If you encounter issues, please describe the steps you took to encounter the issue as much as possible so we can better discover it. If you have any new ideas for our product, we also welcome your feedback. Please share your insights as much as possible so we can get more feedback and further discussion in the community.
|
||||
|
||||
### Share your applications
|
||||
|
||||
We encourage all community members to share their AI applications built on Dify, which can be applied to different scenarios or different users. This will provide powerful inspiration for people who want to create AI capabilities! You can share your experience by [submitting an issue in the Dify-user-case repository](https://github.com/langgenius/dify-user-case/issues).
|
||||
|
||||
### Share Dify with others
|
||||
|
||||
We encourage community contributors to actively demonstrate different aspects of using Dify. You can talk or share any feature of using Dify at meetups and conferences, blogs or social media. We believe your unique sharing will be of great help to others! Mention @Dify.AI on Twitter and/or communicate on [Discord](https://discord.gg/FngNHpbcY7) so we can give pointers and tips and help you spread the word by promoting your content on the different Dify communication channels.
|
||||
|
||||
### Help others
|
||||
You can also help people in need of help on Discord, GitHub issues or other social platforms, guide others to solve problems encountered during use and share usage experiences. This is also a great contribution! If you want to become a maintainer of the Dify community, please contact the official team via [Discord](https://discord.gg/FngNHpbcY7) or email us at support@dify.ai.
|
||||
|
||||
|
||||
## Contact Us
|
||||
|
||||
If you have any questions, suggestions, or partnership inquiries, feel free to contact us through the following channels:
|
||||
@@ -95,12 +130,6 @@ If you have any questions, suggestions, or partnership inquiries, feel free to c
|
||||
|
||||
We're eager to assist you and together create more fun and useful AI applications!
|
||||
|
||||
## Contributing
|
||||
|
||||
To ensure proper review, all code contributions - including those from contributors with direct commit access - must be submitted via pull requests and approved by the core development team prior to being merged.
|
||||
|
||||
We welcome all pull requests! If you'd like to help, check out the [Contribution Guide](CONTRIBUTING.md) for more information on how to get started.
|
||||
|
||||
## Security
|
||||
|
||||
To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer.
|
||||
|
||||
55
README_CN.md
55
README_CN.md
@@ -2,15 +2,13 @@
|
||||
<p align="center">
|
||||
<a href="./README.md">English</a> |
|
||||
<a href="./README_CN.md">简体中文</a> |
|
||||
<a href="./README_JA.md">日本語</a>
|
||||
<a href="./README_JA.md">日本語</a> |
|
||||
<a href="./README_ES.md">Español</a>
|
||||
</p>
|
||||
|
||||
|
||||
[官方网站](https://dify.ai) • [文档](https://docs.dify.ai/v/zh-hans) • [Twitter](https://twitter.com/dify_ai) • [Discord](https://discord.gg/FngNHpbcY7)
|
||||
|
||||
在 Product Hunt 上投我们一票吧 ↓
|
||||
<a href="https://www.producthunt.com/posts/dify-ai"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?sanitize=true&post_id=dify-ai&theme=light" alt="Product Hunt Badge" width="250" height="54"></a>
|
||||
|
||||
**Dify** 是一个易用的 LLMOps 平台,旨在让更多人可以创建可持续运营的原生 AI 应用。Dify 提供多种类型应用的可视化编排,应用可开箱即用,也能以“后端即服务”的 API 提供服务。
|
||||
|
||||
通过 Dify 创建的应用包含了:
|
||||
@@ -19,11 +17,16 @@
|
||||
- 一套 API 即可包含插件、上下文增强等能力,替你省下了后端代码的编写工作
|
||||
- 可视化的对应用进行数据分析,查阅日志或进行标注
|
||||
|
||||
Dify 兼容 Langchain,这意味着我们将逐步支持多种 LLMs ,目前已支持:
|
||||
Dify 兼容 Langchain,这意味着我们将逐步支持多种 LLMs ,目前支持的模型供应商:
|
||||
|
||||
- GPT 3 (text-davinci-003)
|
||||
- GPT 3.5 Turbo(ChatGPT)
|
||||
- GPT-4
|
||||
* **OpenAI**:GPT4、GPT3.5-turbo、GPT3.5-turbo-16k、text-davinci-003
|
||||
|
||||
* **Azure OpenAI Service**
|
||||
* **Anthropic**:Claude2、Claude-instant
|
||||
|
||||
> 我们为所有注册云端版的用户免费提供了 1000 次 Claude 模型的消息调用额度,登录 [dify.ai](https://cloud.dify.ai) 即可使用。
|
||||
|
||||
* **Hugging Face Hub**(即将推出)
|
||||
|
||||
## 使用云服务
|
||||
|
||||
@@ -44,11 +47,16 @@ Dify 兼容 Langchain,这意味着我们将逐步支持多种 LLMs ,目前
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker-compose up -d
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
运行后,可以在浏览器上访问 [http://localhost/install](http://localhost/install) 进入 Dify 控制台并开始初始化安装操作。
|
||||
|
||||
### Helm Chart
|
||||
|
||||
非常感谢 @BorisPolonsky 为我们提供了一个 [Helm Chart](https://helm.sh/) 版本,可以在 Kubernetes 上部署 Dify。
|
||||
您可以前往 https://github.com/BorisPolonsky/dify-helm 来获取部署信息。
|
||||
|
||||
### 配置
|
||||
|
||||
需要自定义配置,请参考我们的 [docker-compose.yml](docker/docker-compose.yaml) 文件中的注释,并手动设置环境配置,修改完毕后,请再次执行 `docker-compose up -d`。
|
||||
@@ -86,6 +94,29 @@ A: 现已支持英文与中文,你可以为我们贡献语言包。
|
||||
|
||||
[](https://star-history.com/#langgenius/dify&Date)
|
||||
|
||||
|
||||
## 贡献
|
||||
|
||||
我们欢迎你为 Dify 作出贡献帮助 Dify 变得更好。我们欢迎各种方式的贡献,提交代码、问题、新想法、或者分享你基于 Dify 创建出的各种有趣有用的 AI 应用。同时,我们也欢迎你在不同的活动、研讨会、社交媒体上分享 Dify。
|
||||
|
||||
### 贡献代码
|
||||
为了确保正确审查,所有代码贡献 - 包括来自具有直接提交更改权限的贡献者 - 都必须提交 PR 请求并在合并分支之前得到核心开发人员的批准。
|
||||
|
||||
我们欢迎所有人提交 PR!如果您愿意提供帮助,可以在 [贡献指南](CONTRIBUTING_CN.md) 中了解有关如何为项目做出代码贡献的更多信息。
|
||||
|
||||
### 提交问题或想法
|
||||
你可以通过 Dify 代码仓库新增 issues 来提交你的问题或想法。如遇到问题,请尽可能描述你遇到问题的操作步骤,以便我们更好地发现它。如果你对我们的产品有任何新想法,也欢迎向我们反馈,请尽可能多地分享你的见解,以便我们在社区中获得更多反馈和进一步讨论。
|
||||
|
||||
### 分享你的应用
|
||||
我们鼓励所有社区成员分享他们基于 Dify 创造出的 AI 应用,它们可以是应用于不同情景或不同用户,这将有助于为希望基于 AI 能力创造的人们提供强大灵感!你可以通过 [Dify-user-case 仓库项目提交 issue](https://github.com/langgenius/dify-user-case) 来分享你的应用案例。
|
||||
|
||||
### 向别人分享 Dify
|
||||
我们鼓励社区贡献者们积极展示你使用 Dify 的不同角度。你可以通过线下研讨会、博客或社交媒体上谈论或分享你使用 Dify 的任意功能,相信你独特的使用分享会给别人带来非常大的帮助!如果你需要任何指导帮助,欢迎联系我们 support@dify.ai ,你也可以在 twitter @Dify.AI 或在 [Discord 社区](https://discord.gg/FngNHpbcY7)交流来帮助你传播信息。
|
||||
|
||||
### 帮助别人
|
||||
你还可以在 Discord、GitHub issues或其他社交平台上帮助需要帮助的人,指导别人解决使用过程中遇到的问题和分享使用经验。这也是个非常了不起的贡献!如果你希望成为 Dify 社区的维护者,请通过[Discord 社区](https://discord.gg/FngNHpbcY7) 联系官方团队或邮件联系我们 support@dify.ai.
|
||||
|
||||
|
||||
## 联系我们
|
||||
|
||||
如果您有任何问题、建议或合作意向,欢迎通过以下方式联系我们:
|
||||
@@ -94,12 +125,6 @@ A: 现已支持英文与中文,你可以为我们贡献语言包。
|
||||
- 在我们的 [Discord 社区](https://discord.gg/FngNHpbcY7) 上加入讨论
|
||||
- 发送邮件至 hello@dify.ai
|
||||
|
||||
## 贡献代码
|
||||
|
||||
为了确保正确审查,所有代码贡献 - 包括来自具有直接提交更改权限的贡献者 - 都必须提交 PR 请求并在合并分支之前得到核心开发人员的批准。
|
||||
|
||||
我们欢迎所有人提交 PR!如果您愿意提供帮助,可以在 [贡献指南](CONTRIBUTING_CN.md) 中了解有关如何为项目做出贡献的更多信息。
|
||||
|
||||
## 安全
|
||||
|
||||
为了保护您的隐私,请避免在 GitHub 上发布安全问题。发送问题至 security@dify.ai,我们将为您做更细致的解答。
|
||||
|
||||
124
README_ES.md
Normal file
124
README_ES.md
Normal file
@@ -0,0 +1,124 @@
|
||||

|
||||
<p align="center">
|
||||
<a href="./README.md">English</a> |
|
||||
<a href="./README_CN.md">简体中文</a> |
|
||||
<a href="./README_JA.md">日本語</a> |
|
||||
<a href="./README_ES.md">Español</a>
|
||||
</p>
|
||||
|
||||
[Sitio web](https://dify.ai) • [Documentación](https://docs.dify.ai) • [Twitter](https://twitter.com/dify_ai) • [Discord](https://discord.gg/FngNHpbcY7)
|
||||
|
||||
**Dify** es una plataforma LLMOps fácil de usar diseñada para capacitar a más personas para que creen aplicaciones sostenibles basadas en IA. Con orquestación visual para varios tipos de aplicaciones, Dify ofrece aplicaciones listas para usar que también pueden funcionar como APIs de Backend-as-a-Service. Unifica tu proceso de desarrollo con una API para la integración de complementos y conjuntos de datos, y agiliza tus operaciones utilizando una interfaz única para la ingeniería de indicaciones, análisis visual y mejora continua.
|
||||
|
||||
Las aplicaciones creadas con Dify incluyen:
|
||||
|
||||
- Sitios web listos para usar que admiten el modo de formulario y el modo de conversación por chat.
|
||||
- Una API única que abarca capacidades de complementos, mejora de contexto y más, lo que te ahorra esfuerzo de programación en el backend.
|
||||
- Análisis visual de datos, revisión de registros y anotación para aplicaciones.
|
||||
|
||||
Dify es compatible con Langchain, lo que significa que gradualmente admitiremos múltiples LLMs, actualmente compatibles con:
|
||||
|
||||
- GPT 3 (text-davinci-003)
|
||||
- GPT 3.5 Turbo (ChatGPT)
|
||||
- GPT-4
|
||||
|
||||
## Usar servicios en la nube
|
||||
|
||||
Visita [Dify.ai](https://dify.ai)
|
||||
|
||||
## Instalar la Edición Comunitaria
|
||||
|
||||
### Requisitos del sistema
|
||||
|
||||
Antes de instalar Dify, asegúrate de que tu máquina cumple con los siguientes requisitos mínimos del sistema:
|
||||
|
||||
- CPU >= 1 Core
|
||||
- RAM >= 4GB
|
||||
|
||||
### Inicio rápido
|
||||
|
||||
La forma más sencilla de iniciar el servidor de Dify es ejecutar nuestro archivo [docker-compose.yml](docker/docker-compose.yaml). Antes de ejecutar el comando de instalación, asegúrate de que [Docker](https://docs.docker.com/get-docker/) y [Docker Compose](https://docs.docker.com/compose/install/) estén instalados en tu máquina:
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
Después de ejecutarlo, puedes acceder al panel de control de Dify en tu navegador desde [http://localhost/install](http://localhost/install) y comenzar el proceso de instalación de inicialización.
|
||||
|
||||
### Helm Chart
|
||||
|
||||
Un gran agradecimiento a @BorisPolonsky por proporcionarnos una versión de [Helm Chart](https://helm.sh/), que permite desplegar Dify en Kubernetes.
|
||||
Puede ir a https://github.com/BorisPolonsky/dify-helm para obtener información de despliegue.
|
||||
|
||||
### Configuración
|
||||
|
||||
Si necesitas personalizar la configuración, consulta los comentarios en nuestro archivo [docker-compose.yml](docker/docker-compose.yaml) y configura manualmente la configuración del entorno. Después de realizar los cambios, ejecuta nuevamente 'docker-compose up -d'.
|
||||
|
||||
## Hoja de ruta
|
||||
|
||||
Funciones en desarrollo:
|
||||
|
||||
- **Conjuntos de datos**, admitiendo más conjuntos de datos, por ejemplo, sincronización de contenido desde Notion o páginas web.
|
||||
Admitiremos más conjuntos de datos, incluidos texto, páginas web e incluso contenido de Notion. Los usuarios pueden construir aplicaciones de IA basadas en sus propias fuentes de datos
|
||||
- **Complementos**, introduciendo complementos estándar de ChatGPT para aplicaciones, o utilizando complementos producidos por Dify.
|
||||
Lanzaremos complementos que cumplan con el estándar de ChatGPT, o nuestros propios complementos de Dify para habilitar más capacidades en las aplicaciones.
|
||||
- **Modelos de código abierto**, por ejemplo, adoptar Llama como proveedor de modelos o para un ajuste adicional.
|
||||
Trabajaremos con excelentes modelos de código abierto como Llama, proporcionándolos como opciones de modelos en nuestra plataforma o utilizándolos para un ajuste adicional.
|
||||
|
||||
## Preguntas y respuestas
|
||||
|
||||
**P: ¿Qué puedo hacer con Dify?**
|
||||
|
||||
R: Dify es una herramienta de desarrollo y operaciones de LLM, simple pero poderosa. Puedes usarla para construir aplicaciones de calidad comercial y asistentes personales. Si deseas desarrollar tus propias aplicaciones, LangDifyGenius puede ahorrarte trabajo en el backend al integrar con OpenAI y ofrecer capacidades de operaciones visuales, lo que te permite mejorar y entrenar continuamente tu modelo GPT.
|
||||
|
||||
**P: ¿Cómo uso Dify para "entrenar" mi propio modelo?**
|
||||
|
||||
R: Una aplicación valiosa consta de Ingeniería de indicaciones, mejora de contexto y ajuste fino. Hemos creado un enfoque de programación híbrida que combina las indicaciones con lenguajes de programación (similar a un motor de plantillas), lo que facilita la incorporación de texto largo o la captura de subtítulos de un video de YouTube ingresado por el usuario, todo lo cual se enviará como contexto para que los LLM lo procesen. Damos gran importancia a la operabilidad de la aplicación, con los datos generados por los usuarios durante el uso de la aplicación disponibles para análisis, anotación y entrenamiento continuo. Sin las herramientas adecuadas, estos pasos pueden llevar mucho tiempo.
|
||||
|
||||
**P: ¿Qué necesito preparar si quiero crear mi propia aplicación?**
|
||||
|
||||
R: Suponemos que ya tienes una clave de API de OpenAI; si no la tienes, por favor regístrate. ¡Si ya tienes contenido que pueda servir como contexto de entrenamiento, eso es genial!
|
||||
|
||||
**P: ¿Qué idiomas de interfaz están disponibles?**
|
||||
|
||||
R: Actualmente se admiten inglés y chino, y puedes contribuir con paquetes de idiomas.
|
||||
|
||||
## Historial de estrellas
|
||||
|
||||
[](https://star-history.com/#langgenius/dify&Date)
|
||||
|
||||
## Contáctanos
|
||||
|
||||
Si tienes alguna pregunta, sugerencia o consulta sobre asociación, no dudes en contactarnos a través de los siguientes canales:
|
||||
|
||||
- Presentar un problema o una solicitud de extracción en nuestro repositorio de GitHub.
|
||||
- Únete a la discusión en nuestra comunidad de [Discord](https://discord.gg/FngNHpbcY7).
|
||||
- Envía un correo electrónico a hello@dify.ai.
|
||||
|
||||
¡Estamos ansiosos por ayudarte y crear juntos aplicaciones de IA más divertidas y útiles!
|
||||
|
||||
## Contribuciones
|
||||
|
||||
Para garantizar una revisión adecuada, todas las contribuciones de código, incluidas las de los colaboradores con acceso directo a los compromisos, deben enviarse mediante solicitudes de extracción y ser aprobadas por el equipo principal de
|
||||
|
||||
desarrollo antes de fusionarse.
|
||||
|
||||
¡Agradecemos todas las solicitudes de extracción! Si deseas ayudar, consulta la [Guía de Contribución](CONTRIBUTING.md) para obtener más información sobre cómo comenzar.
|
||||
|
||||
## Seguridad
|
||||
|
||||
Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En su lugar, envía tus preguntas a security@dify.ai y te proporcionaremos una respuesta más detallada.
|
||||
|
||||
## Citación
|
||||
|
||||
Este software utiliza el siguiente software de código abierto:
|
||||
|
||||
- Chase, H. (2022). LangChain [Software de computadora]. https://github.com/hwchase17/langchain
|
||||
- Liu, J. (2022). LlamaIndex [Software de computadora]. doi: 10.5281/zenodo.1234.
|
||||
|
||||
Para obtener más información, consulta el sitio web oficial o el texto de la licencia del software correspondiente.
|
||||
|
||||
## Licencia
|
||||
|
||||
Este repositorio está disponible bajo la [Licencia de código abierto de Dify](LICENSE).
|
||||
13
README_JA.md
13
README_JA.md
@@ -2,14 +2,12 @@
|
||||
<p align="center">
|
||||
<a href="./README.md">English</a> |
|
||||
<a href="./README_CN.md">简体中文</a> |
|
||||
<a href="./README_JA.md">日本語</a>
|
||||
<a href="./README_JA.md">日本語</a> |
|
||||
<a href="./README_ES.md">Español</a>
|
||||
</p>
|
||||
|
||||
[Web サイト](https://dify.ai) • [ドキュメント](https://docs.dify.ai) • [Twitter](https://twitter.com/dify_ai) • [Discord](https://discord.gg/FngNHpbcY7)
|
||||
|
||||
Product Huntで私たちに投票してください ↓
|
||||
<a href="https://www.producthunt.com/posts/dify-ai"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?sanitize=true&post_id=dify-ai&theme=light" alt="Product Hunt Badge" width="250" height="54"></a>
|
||||
|
||||
|
||||
**Dify** は、より多くの人々が持続可能な AI ネイティブアプリケーションを作成できるように設計された、使いやすい LLMOps プラットフォームです。様々なアプリケーションタイプに対応したビジュアルオーケストレーションにより Dify は Backend-as-a-Service API としても機能する、すぐに使えるアプリケーションを提供します。プラグインやデータセットを統合するための1つの API で開発プロセスを統一し、プロンプトエンジニアリング、ビジュアル分析、継続的な改善のための1つのインターフェイスを使って業務を合理化します。
|
||||
|
||||
@@ -43,11 +41,16 @@ Dify サーバーを起動する最も簡単な方法は、[docker-compose.yml](
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker-compose up -d
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
実行後、ブラウザで [http://localhost/install](http://localhost/install) にアクセスし、初期化インストール作業を開始することができます。
|
||||
|
||||
### Helm Chart
|
||||
|
||||
@BorisPolonsky に大感謝します。彼は Dify を Kubernetes 上にデプロイするための [Helm Chart](https://helm.sh/) バージョンを提供してくれました。
|
||||
デプロイ情報については、https://github.com/BorisPolonsky/dify-helm をご覧ください。
|
||||
|
||||
### 構成
|
||||
|
||||
カスタマイズが必要な場合は、[docker-compose.yml](docker/docker-compose.yaml) ファイルのコメントを参照し、手動で環境設定をお願いします。変更後、再度 'docker-compose up -d' を実行してください。
|
||||
|
||||
@@ -8,13 +8,19 @@ EDITION=SELF_HOSTED
|
||||
SECRET_KEY=
|
||||
|
||||
# Console API base URL
|
||||
CONSOLE_URL=http://127.0.0.1:5001
|
||||
CONSOLE_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Console frontend web base URL
|
||||
CONSOLE_WEB_URL=http://127.0.0.1:3000
|
||||
|
||||
# Service API base URL
|
||||
API_URL=http://127.0.0.1:5001
|
||||
SERVICE_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Web APP base URL
|
||||
APP_URL=http://127.0.0.1:3000
|
||||
# Web APP API base URL
|
||||
APP_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Web APP frontend web base URL
|
||||
APP_WEB_URL=http://127.0.0.1:3000
|
||||
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
||||
@@ -22,6 +28,7 @@ CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
||||
# redis configuration
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_USERNAME=
|
||||
REDIS_PASSWORD=difyai123456
|
||||
REDIS_DB=0
|
||||
|
||||
@@ -72,14 +79,26 @@ VECTOR_STORE=weaviate
|
||||
WEAVIATE_ENDPOINT=http://localhost:8080
|
||||
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
|
||||
# Qdrant configuration, use `path:` prefix for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
|
||||
QDRANT_URL=path:storage/qdrant
|
||||
QDRANT_API_KEY=your-qdrant-api-key
|
||||
|
||||
# Mail configuration, support: resend
|
||||
MAIL_TYPE=
|
||||
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
|
||||
RESEND_API_KEY=
|
||||
|
||||
# Sentry configuration
|
||||
SENTRY_DSN=
|
||||
|
||||
# DEBUG
|
||||
DEBUG=false
|
||||
SQLALCHEMY_ECHO=false
|
||||
|
||||
# Notion import configuration, support public and internal
|
||||
NOTION_INTEGRATION_TYPE=public
|
||||
NOTION_CLIENT_SECRET=you-client-secret
|
||||
NOTION_CLIENT_ID=you-client-id
|
||||
NOTION_INTERNAL_SECRET=you-internal-secret
|
||||
|
||||
@@ -5,9 +5,11 @@ LABEL maintainer="takatost@gmail.com"
|
||||
ENV FLASK_APP app.py
|
||||
ENV EDITION SELF_HOSTED
|
||||
ENV DEPLOY_ENV PRODUCTION
|
||||
ENV CONSOLE_URL http://127.0.0.1:5001
|
||||
ENV API_URL http://127.0.0.1:5001
|
||||
ENV APP_URL http://127.0.0.1:5001
|
||||
ENV CONSOLE_API_URL http://127.0.0.1:5001
|
||||
ENV CONSOLE_WEB_URL http://127.0.0.1:3000
|
||||
ENV SERVICE_API_URL http://127.0.0.1:5001
|
||||
ENV APP_API_URL http://127.0.0.1:5001
|
||||
ENV APP_WEB_URL http://127.0.0.1:3000
|
||||
|
||||
EXPOSE 5001
|
||||
|
||||
@@ -25,4 +27,4 @@ RUN chmod +x /entrypoint.sh
|
||||
ARG COMMIT_SHA
|
||||
ENV COMMIT_SHA ${COMMIT_SHA}
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
|
||||
@@ -17,6 +17,11 @@
|
||||
```bash
|
||||
openssl rand -base64 42
|
||||
```
|
||||
3.5 If you use annaconda, create a new environment and activate it
|
||||
```bash
|
||||
conda create --name dify python=3.10
|
||||
conda activate dify
|
||||
```
|
||||
4. Install dependencies
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
|
||||
26
api/app.py
26
api/app.py
@@ -1,5 +1,9 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
||||
from gevent import monkey
|
||||
monkey.patch_all()
|
||||
@@ -12,20 +16,20 @@ from flask import Flask, request, Response, session
|
||||
import flask_login
|
||||
from flask_cors import CORS
|
||||
|
||||
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \
|
||||
ext_database, ext_storage
|
||||
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||
ext_database, ext_storage, ext_mail
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
|
||||
# DO NOT REMOVE BELOW
|
||||
from models import model, account, dataset, web, task
|
||||
from models import model, account, dataset, web, task, source, tool
|
||||
from events import event_handlers
|
||||
# DO NOT REMOVE ABOVE
|
||||
|
||||
import core
|
||||
from config import Config, CloudEditionConfig
|
||||
from commands import register_commands
|
||||
from models.account import TenantAccountJoin
|
||||
from models.account import TenantAccountJoin, AccountStatus
|
||||
from models.model import Account, EndUser, App
|
||||
|
||||
import warnings
|
||||
@@ -77,11 +81,11 @@ def initialize_extensions(app):
|
||||
ext_database.init_app(app)
|
||||
ext_migrate.init(app, db)
|
||||
ext_redis.init_app(app)
|
||||
ext_vector_store.init_app(app)
|
||||
ext_storage.init_app(app)
|
||||
ext_celery.init_app(app)
|
||||
ext_session.init_app(app)
|
||||
ext_login.init_app(app)
|
||||
ext_mail.init_app(app)
|
||||
ext_sentry.init_app(app)
|
||||
|
||||
|
||||
@@ -99,6 +103,9 @@ def load_user(user_id):
|
||||
account = db.session.query(Account).filter(Account.id == account_id).first()
|
||||
|
||||
if account:
|
||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
||||
raise Forbidden('Account is banned or closed.')
|
||||
|
||||
workspace_id = session.get('workspace_id')
|
||||
if workspace_id:
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||
@@ -122,6 +129,9 @@ def load_user(user_id):
|
||||
account.current_tenant_id = tenant_account_join.tenant_id
|
||||
session['workspace_id'] = account.current_tenant_id
|
||||
|
||||
account.last_active_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
# Log in the user with the updated user_id
|
||||
flask_login.login_user(account, remember=True)
|
||||
|
||||
@@ -145,13 +155,17 @@ def register_blueprints(app):
|
||||
from controllers.web import bp as web_bp
|
||||
from controllers.console import bp as console_app_bp
|
||||
|
||||
CORS(service_api_bp,
|
||||
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
|
||||
)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
CORS(web_bp,
|
||||
resources={
|
||||
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
|
||||
supports_credentials=True,
|
||||
allow_headers=['Content-Type', 'Authorization'],
|
||||
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
|
||||
expose_headers=['X-Version', 'X-Env']
|
||||
)
|
||||
|
||||
157
api/commands.py
157
api/commands.py
@@ -1,17 +1,27 @@
|
||||
import datetime
|
||||
import logging
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
|
||||
import click
|
||||
from flask import current_app
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from libs.password import password_pattern, valid_password, hash_password
|
||||
from libs.helper import email as email_validate
|
||||
from extensions.ext_database import db
|
||||
from models.account import InvitationCode
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import InvitationCode, Tenant
|
||||
from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
|
||||
from models.model import Account
|
||||
import secrets
|
||||
import base64
|
||||
|
||||
from models.provider import Provider, ProviderName
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
@click.command('reset-password', help='Reset the account password.')
|
||||
@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
|
||||
@@ -73,6 +83,31 @@ def reset_email(email, new_email, email_confirm):
|
||||
click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
|
||||
|
||||
|
||||
@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. '
|
||||
'After the reset, all LLM credentials will become invalid, '
|
||||
'requiring re-entry.'
|
||||
'Only support SELF_HOSTED mode.')
|
||||
@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
|
||||
' this operation cannot be rolled back!', fg='red'))
|
||||
def reset_encrypt_key_pair():
|
||||
if current_app.config['EDITION'] != 'SELF_HOSTED':
|
||||
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
|
||||
return
|
||||
|
||||
tenant = db.session.query(Tenant).first()
|
||||
if not tenant:
|
||||
click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
|
||||
return
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
db.session.query(Provider).filter(Provider.provider_type == 'custom').delete()
|
||||
db.session.commit()
|
||||
|
||||
click.echo(click.style('Congratulations! '
|
||||
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
|
||||
|
||||
|
||||
@click.command('generate-invitation-codes', help='Generate invitation codes.')
|
||||
@click.option('--batch', help='The batch of invitation codes.')
|
||||
@click.option('--count', prompt=True, help='Invitation codes count.')
|
||||
@@ -130,7 +165,127 @@ def generate_upper_string():
|
||||
return result
|
||||
|
||||
|
||||
@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
|
||||
def recreate_all_dataset_indexes():
|
||||
click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
|
||||
recreate_count = 0
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
try:
|
||||
click.echo('Recreating dataset index: {}'.format(dataset.id))
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
if index and index._is_origin():
|
||||
index.recreate_dataset(dataset)
|
||||
recreate_count += 1
|
||||
else:
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
|
||||
|
||||
|
||||
@click.command('clean-unused-dataset-indexes', help='Clean unused dataset indexes.')
|
||||
def clean_unused_dataset_indexes():
|
||||
click.echo(click.style('Start clean unused dataset indexes.', fg='green'))
|
||||
clean_days = int(current_app.config.get('CLEAN_DAY_SETTING'))
|
||||
start_at = time.perf_counter()
|
||||
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.created_at < thirty_days_ago) \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
dataset_query = db.session.query(DatasetQuery).filter(
|
||||
DatasetQuery.created_at > thirty_days_ago,
|
||||
DatasetQuery.dataset_id == dataset.id
|
||||
).all()
|
||||
if not dataset_query or len(dataset_query) == 0:
|
||||
documents = db.session.query(Document).filter(
|
||||
Document.dataset_id == dataset.id,
|
||||
Document.indexing_status == 'completed',
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
Document.updated_at > thirty_days_ago
|
||||
).all()
|
||||
if not documents or len(documents) == 0:
|
||||
try:
|
||||
# remove index
|
||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
# delete from vector index
|
||||
if vector_index:
|
||||
vector_index.delete()
|
||||
kw_index.delete()
|
||||
# update document
|
||||
update_params = {
|
||||
Document.enabled: False
|
||||
}
|
||||
|
||||
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
|
||||
db.session.commit()
|
||||
click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id),
|
||||
fg='green'))
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
end_at = time.perf_counter()
|
||||
click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green'))
|
||||
|
||||
|
||||
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
|
||||
def sync_anthropic_hosted_providers():
|
||||
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
|
||||
count = 0
|
||||
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
page += 1
|
||||
for tenant in tenants:
|
||||
try:
|
||||
click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
ProviderName.ANTHROPIC.value,
|
||||
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
|
||||
True
|
||||
)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
click.echo(click.style(
|
||||
'Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
app.cli.add_command(generate_invitation_codes)
|
||||
app.cli.add_command(reset_encrypt_key_pair)
|
||||
app.cli.add_command(recreate_all_dataset_indexes)
|
||||
app.cli.add_command(sync_anthropic_hosted_providers)
|
||||
app.cli.add_command(clean_unused_dataset_indexes)
|
||||
|
||||
@@ -28,9 +28,11 @@ DEFAULTS = {
|
||||
'SESSION_REDIS_USE_SSL': 'False',
|
||||
'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
|
||||
'OAUTH_REDIRECT_INDEX_PATH': '/',
|
||||
'CONSOLE_URL': 'https://cloud.dify.ai',
|
||||
'API_URL': 'https://api.dify.ai',
|
||||
'APP_URL': 'https://udify.app',
|
||||
'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
|
||||
'CONSOLE_API_URL': 'https://cloud.dify.ai',
|
||||
'SERVICE_API_URL': 'https://api.dify.ai',
|
||||
'APP_WEB_URL': 'https://udify.app',
|
||||
'APP_API_URL': 'https://udify.app',
|
||||
'STORAGE_TYPE': 'local',
|
||||
'STORAGE_LOCAL_PATH': 'storage',
|
||||
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
|
||||
@@ -43,11 +45,16 @@ DEFAULTS = {
|
||||
'SENTRY_TRACES_SAMPLE_RATE': 1.0,
|
||||
'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
|
||||
'WEAVIATE_GRPC_ENABLED': 'True',
|
||||
'WEAVIATE_BATCH_SIZE': 100,
|
||||
'CELERY_BACKEND': 'database',
|
||||
'PDF_PREVIEW': 'True',
|
||||
'LOG_LEVEL': 'INFO',
|
||||
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
|
||||
'DEFAULT_LLM_PROVIDER': 'openai'
|
||||
'DEFAULT_LLM_PROVIDER': 'openai',
|
||||
'OPENAI_HOSTED_QUOTA_LIMIT': 200,
|
||||
'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
|
||||
'TENANT_DOCUMENT_COUNT': 100,
|
||||
'CLEAN_DAY_SETTING': 30
|
||||
}
|
||||
|
||||
|
||||
@@ -75,10 +82,15 @@ class Config:
|
||||
|
||||
def __init__(self):
|
||||
# app settings
|
||||
self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
|
||||
self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
|
||||
self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
|
||||
self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
|
||||
self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
self.CURRENT_VERSION = "0.3.1"
|
||||
self.CURRENT_VERSION = "0.3.12"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
@@ -138,6 +150,7 @@ class Config:
|
||||
self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
|
||||
self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
|
||||
self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
|
||||
self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
|
||||
|
||||
# qdrant settings
|
||||
self.QDRANT_URL = get_env('QDRANT_URL')
|
||||
@@ -145,10 +158,15 @@ class Config:
|
||||
|
||||
# cors settings
|
||||
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_URL)
|
||||
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
|
||||
self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||
'WEB_API_CORS_ALLOW_ORIGINS', '*')
|
||||
|
||||
# mail settings
|
||||
self.MAIL_TYPE = get_env('MAIL_TYPE')
|
||||
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
|
||||
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
|
||||
|
||||
# sentry settings
|
||||
self.SENTRY_DSN = get_env('SENTRY_DSN')
|
||||
self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
|
||||
@@ -177,6 +195,10 @@ class Config:
|
||||
|
||||
# hosted provider credentials
|
||||
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
|
||||
self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
|
||||
|
||||
self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
|
||||
self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
|
||||
|
||||
# By default it is False
|
||||
# You could disable it for compatibility with certain OpenAPI providers
|
||||
@@ -186,6 +208,17 @@ class Config:
|
||||
# set default LLM provider, default is 'openai', support `azure_openai`
|
||||
self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
|
||||
|
||||
# notion import setting
|
||||
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
|
||||
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
|
||||
self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
|
||||
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
|
||||
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
|
||||
|
||||
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
|
||||
self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
|
||||
|
||||
|
||||
class CloudEditionConfig(Config):
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -9,16 +9,19 @@ api = ExternalApi(bp)
|
||||
from . import setup, version, apikey, admin
|
||||
|
||||
# Import app controllers
|
||||
from .app import app, site, completion, model_config, statistic, conversation, message
|
||||
from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import login, oauth
|
||||
from .auth import login, oauth, data_source_oauth, activate
|
||||
|
||||
# Import datasets controllers
|
||||
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing
|
||||
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
|
||||
|
||||
# Import workspace controllers
|
||||
from .workspace import workspace, members, providers, account
|
||||
from .workspace import workspace, members, model_providers, account, tool_providers
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message
|
||||
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
|
||||
|
||||
# Import universal chat controllers
|
||||
from .universal_chat import chat, conversation, message, parameter, audio
|
||||
|
||||
@@ -8,6 +8,7 @@ from werkzeug.exceptions import NotFound, Unauthorized
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import supported_language
|
||||
from models.model import RecommendedApp, App, InstalledApp
|
||||
|
||||
|
||||
@@ -47,8 +48,7 @@ class InsertExploreAppListApi(Resource):
|
||||
parser.add_argument('desc', type=str, location='json')
|
||||
parser.add_argument('copyright', type=str, location='json')
|
||||
parser.add_argument('privacy_policy', type=str, location='json')
|
||||
parser.add_argument('language', type=str, required=True, nullable=False, choices=['en-US', 'zh-Hans'],
|
||||
location='json')
|
||||
parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json')
|
||||
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('position', type=int, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -9,25 +9,22 @@ from werkzeug.exceptions import Unauthorized, Forbidden
|
||||
|
||||
from constants.model_template import model_templates, demo_model_templates
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError, ProviderQuotaExceededError, \
|
||||
CompletionRequestError, ProviderModelCurrentlyNotSupportError
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
|
||||
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
|
||||
from events.app_event import app_was_created, app_was_deleted
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig, Site, InstalledApp
|
||||
from services.account_service import TenantService
|
||||
from models.model import App, AppModelConfig, Site
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
model_config_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
|
||||
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
|
||||
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
|
||||
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
|
||||
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'user_input_form': fields.Raw(attribute='user_input_form_list'),
|
||||
'pre_prompt': fields.String,
|
||||
@@ -100,7 +97,8 @@ class AppListApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
app_models = db.paginate(
|
||||
db.select(App).where(App.tenant_id == current_user.current_tenant_id).order_by(App.created_at.desc()),
|
||||
db.select(App).where(App.tenant_id == current_user.current_tenant_id,
|
||||
App.is_universal == False).order_by(App.created_at.desc()),
|
||||
page=args['page'],
|
||||
per_page=args['limit'],
|
||||
error_out=False)
|
||||
@@ -149,7 +147,9 @@ class AppListApi(Resource):
|
||||
opening_statement=model_configuration['opening_statement'],
|
||||
suggested_questions=json.dumps(model_configuration['suggested_questions']),
|
||||
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
|
||||
speech_to_text=json.dumps(model_configuration['speech_to_text']),
|
||||
more_like_this=json.dumps(model_configuration['more_like_this']),
|
||||
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
|
||||
model=json.dumps(model_configuration['model']),
|
||||
user_input_form=json.dumps(model_configuration['user_input_form']),
|
||||
pre_prompt=model_configuration['pre_prompt'],
|
||||
@@ -220,7 +220,11 @@ class AppTemplateApi(Resource):
|
||||
account = current_user
|
||||
interface_language = account.interface_language
|
||||
|
||||
return {'data': demo_model_templates.get(interface_language)}
|
||||
templates = demo_model_templates.get(interface_language)
|
||||
if not templates:
|
||||
templates = demo_model_templates.get('en-US')
|
||||
|
||||
return {'data': templates}
|
||||
|
||||
|
||||
class AppApi(Resource):
|
||||
@@ -435,7 +439,9 @@ class AppCopy(Resource):
|
||||
opening_statement=app_config.opening_statement,
|
||||
suggested_questions=app_config.suggested_questions,
|
||||
suggested_questions_after_answer=app_config.suggested_questions_after_answer,
|
||||
speech_to_text=app_config.speech_to_text,
|
||||
more_like_this=app_config.more_like_this,
|
||||
sensitive_word_avoidance=app_config.sensitive_word_avoidance,
|
||||
model=app_config.model,
|
||||
user_input_form=app_config.user_input_form,
|
||||
pre_prompt=app_config.pre_prompt,
|
||||
@@ -478,35 +484,6 @@ class AppExport(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class IntroductionGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('prompt_template', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = current_user
|
||||
|
||||
try:
|
||||
answer = LLMGenerator.generate_introduction(
|
||||
account.current_tenant_id,
|
||||
args['prompt_template']
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
|
||||
return {'introduction': answer}
|
||||
|
||||
|
||||
api.add_resource(AppListApi, '/apps')
|
||||
api.add_resource(AppTemplateApi, '/app-templates')
|
||||
api.add_resource(AppApi, '/apps/<uuid:app_id>')
|
||||
@@ -515,4 +492,3 @@ api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
|
||||
api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
|
||||
api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
|
||||
api.add_resource(AppRateLimit, '/apps/<uuid:app_id>/rate-limit')
|
||||
api.add_resource(IntroductionGenerateApi, '/introduction-generate')
|
||||
|
||||
69
api/controllers/console/app/audio.py
Normal file
69
api/controllers/console/app/audio.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.app.error import AppUnavailableError, \
|
||||
ProviderNotInitializeError, CompletionRequestError, ProviderQuotaExceededError, \
|
||||
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from flask_restful import Resource
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
|
||||
|
||||
class ChatMessageAudioApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
app_model = _get_app(app_id, 'chat')
|
||||
|
||||
file = request.files['file']
|
||||
|
||||
try:
|
||||
response = AudioService.transcript(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file,
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
|
||||
@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -95,6 +95,7 @@ class CompletionConversationApi(Resource):
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_end_user_session_id': fields.String(attribute='end_user.session_id'),
|
||||
'from_account_id': fields.String,
|
||||
'read_at': TimestampField,
|
||||
'created_at': TimestampField,
|
||||
@@ -135,6 +136,8 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'completion')
|
||||
|
||||
query = query.options(joinedload(Conversation.end_user))
|
||||
|
||||
if args['keyword']:
|
||||
query = query.join(
|
||||
Message, Message.conversation_id == Conversation.id
|
||||
@@ -160,7 +163,7 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
@@ -209,6 +212,26 @@ class CompletionConversationDetailApi(Resource):
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_id, conversation_id, 'completion')
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id, conversation_id):
|
||||
app_id = str(app_id)
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
app = _get_app(app_id, 'chat')
|
||||
|
||||
conversation = db.session.query(Conversation) \
|
||||
.filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
conversation.is_deleted = True
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class ChatConversationApi(Resource):
|
||||
@@ -226,6 +249,7 @@ class ChatConversationApi(Resource):
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_end_user_session_id': fields.String(attribute='end_user.session_id'),
|
||||
'from_account_id': fields.String,
|
||||
'summary': fields.String(attribute='summary_or_query'),
|
||||
'read_at': TimestampField,
|
||||
@@ -268,6 +292,8 @@ class ChatConversationApi(Resource):
|
||||
|
||||
query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'chat')
|
||||
|
||||
query = query.options(joinedload(Conversation.end_user))
|
||||
|
||||
if args['keyword']:
|
||||
query = query.join(
|
||||
Message, Message.conversation_id == Conversation.id
|
||||
@@ -296,7 +322,7 @@ class ChatConversationApi(Resource):
|
||||
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
@@ -356,6 +382,27 @@ class ChatConversationDetailApi(Resource):
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_id, conversation_id, 'chat')
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id, conversation_id):
|
||||
app_id = str(app_id)
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
# get app info
|
||||
app = _get_app(app_id, 'chat')
|
||||
|
||||
conversation = db.session.query(Conversation) \
|
||||
.filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
conversation.is_deleted = True
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException):
|
||||
|
||||
class ProviderQuotaExceededError(BaseHTTPException):
|
||||
error_code = 'provider_quota_exceeded'
|
||||
description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
|
||||
description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
||||
code = 400
|
||||
|
||||
@@ -49,3 +49,27 @@ class AppMoreLikeThisDisabledError(BaseHTTPException):
|
||||
error_code = 'app_more_like_this_disabled'
|
||||
description = "The 'More like this' feature is disabled. Please refresh your page."
|
||||
code = 403
|
||||
|
||||
|
||||
class NoAudioUploadedError(BaseHTTPException):
|
||||
error_code = 'no_audio_uploaded'
|
||||
description = "Please upload your audio."
|
||||
code = 400
|
||||
|
||||
|
||||
class AudioTooLargeError(BaseHTTPException):
|
||||
error_code = 'audio_too_large'
|
||||
description = "Audio size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedAudioTypeError(BaseHTTPException):
|
||||
error_code = 'unsupported_audio_type'
|
||||
description = "Audio type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
error_code = 'provider_not_support_speech_to_text'
|
||||
description = "Provider not support speech to text."
|
||||
code = 400
|
||||
75
api/controllers/console/app/generator.py
Normal file
75
api/controllers/console/app/generator.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
|
||||
CompletionRequestError, ProviderModelCurrentlyNotSupportError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
|
||||
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
|
||||
|
||||
|
||||
class IntroductionGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('prompt_template', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = current_user
|
||||
|
||||
try:
|
||||
answer = LLMGenerator.generate_introduction(
|
||||
account.current_tenant_id,
|
||||
args['prompt_template']
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
|
||||
return {'introduction': answer}
|
||||
|
||||
|
||||
class RuleGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('audiences', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('hoping_to_solve', type=str, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = current_user
|
||||
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
account.current_tenant_id,
|
||||
args['audiences'],
|
||||
args['hoping_to_solve']
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
|
||||
return rules
|
||||
|
||||
|
||||
api.add_resource(IntroductionGenerateApi, '/introduction-generate')
|
||||
api.add_resource(RuleGenerateApi, '/rule-generate')
|
||||
@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
|
||||
except MoreLikeThisDisabledError:
|
||||
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -41,7 +41,9 @@ class ModelConfigResource(Resource):
|
||||
opening_statement=model_configuration['opening_statement'],
|
||||
suggested_questions=json.dumps(model_configuration['suggested_questions']),
|
||||
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
|
||||
speech_to_text=json.dumps(model_configuration['speech_to_text']),
|
||||
more_like_this=json.dumps(model_configuration['more_like_this']),
|
||||
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
|
||||
model=json.dumps(model_configuration['model']),
|
||||
user_input_form=json.dumps(model_configuration['user_input_form']),
|
||||
pre_prompt=model_configuration['pre_prompt'],
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
@@ -59,18 +60,20 @@ class DailyConversationStatistic(Resource):
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
rs = db.session.execute(sql_query, arg_dict)
|
||||
|
||||
response_date = []
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
|
||||
response_data = []
|
||||
|
||||
for i in rs:
|
||||
response_date.append({
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'conversation_count': i.conversation_count
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_date
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
@@ -119,18 +122,20 @@ class DailyTerminalsStatistic(Resource):
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
rs = db.session.execute(sql_query, arg_dict)
|
||||
|
||||
response_date = []
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
|
||||
response_data = []
|
||||
|
||||
for i in rs:
|
||||
response_date.append({
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'terminal_count': i.terminal_count
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_date
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
@@ -180,12 +185,14 @@ class DailyTokenCostStatistic(Resource):
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
rs = db.session.execute(sql_query, arg_dict)
|
||||
|
||||
response_date = []
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
|
||||
response_data = []
|
||||
|
||||
for i in rs:
|
||||
response_date.append({
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'token_count': i.token_count,
|
||||
'total_price': i.total_price,
|
||||
@@ -193,10 +200,207 @@ class DailyTokenCostStatistic(Resource):
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_date
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
class AverageSessionInteractionStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
account = current_user
|
||||
app_id = str(app_id)
|
||||
app_model = _get_app(app_id, 'chat')
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
AVG(subquery.message_count) AS interactions
|
||||
FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
|
||||
FROM conversations c
|
||||
JOIN messages m ON c.id = m.conversation_id
|
||||
WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += ' and c.created_at >= :start'
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += ' and c.created_at < :end'
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += """
|
||||
GROUP BY m.conversation_id) subquery
|
||||
LEFT JOIN conversations c on c.id=subquery.conversation_id
|
||||
GROUP BY date
|
||||
ORDER BY date"""
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
|
||||
response_data = []
|
||||
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'interactions': float(i.interactions.quantize(Decimal('0.01')))
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
class UserSatisfactionRateStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
account = current_user
|
||||
app_id = str(app_id)
|
||||
app_model = _get_app(app_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = '''
|
||||
SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
|
||||
FROM messages m
|
||||
LEFT JOIN message_feedbacks mf on mf.message_id=m.id
|
||||
WHERE m.app_id = :app_id
|
||||
'''
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += ' and m.created_at >= :start'
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += ' and m.created_at < :end'
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
|
||||
response_data = []
|
||||
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
class AverageResponseTimeStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
account = current_user
|
||||
app_id = str(app_id)
|
||||
app_model = _get_app(app_id, 'completion')
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = '''
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
AVG(provider_response_latency) as latency
|
||||
FROM messages
|
||||
WHERE app_id = :app_id
|
||||
'''
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += ' and created_at >= :start'
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += ' and created_at < :end'
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
|
||||
response_data = []
|
||||
|
||||
for i in rs:
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'latency': round(i.latency * 1000, 4)
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
api.add_resource(DailyConversationStatistic, '/apps/<uuid:app_id>/statistics/daily-conversations')
|
||||
api.add_resource(DailyTerminalsStatistic, '/apps/<uuid:app_id>/statistics/daily-end-users')
|
||||
api.add_resource(DailyTokenCostStatistic, '/apps/<uuid:app_id>/statistics/token-costs')
|
||||
api.add_resource(AverageSessionInteractionStatistic, '/apps/<uuid:app_id>/statistics/average-session-interactions')
|
||||
api.add_resource(UserSatisfactionRateStatistic, '/apps/<uuid:app_id>/statistics/user-satisfaction-rate')
|
||||
api.add_resource(AverageResponseTimeStatistic, '/apps/<uuid:app_id>/statistics/average-response-time')
|
||||
|
||||
75
api/controllers/console/auth/activate.py
Normal file
75
api/controllers/console/auth/activate.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import base64
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, str_len, supported_language, timezone
|
||||
from libs.password import valid_password, hash_password
|
||||
from models.account import AccountStatus, Tenant
|
||||
from services.account_service import RegisterService
|
||||
|
||||
|
||||
class ActivateCheckApi(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='args')
|
||||
parser.add_argument('email', type=email, required=True, nullable=False, location='args')
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
|
||||
|
||||
tenant = db.session.query(Tenant).filter(
|
||||
Tenant.id == args['workspace_id'],
|
||||
Tenant.status == 'normal'
|
||||
).first()
|
||||
|
||||
return {'is_valid': account is not None, 'workspace_name': tenant.name}
|
||||
|
||||
|
||||
class ActivateApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('email', type=email, required=True, nullable=False, location='json')
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
|
||||
parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
|
||||
parser.add_argument('interface_language', type=supported_language, required=True, nullable=False,
|
||||
location='json')
|
||||
parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
|
||||
if account is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
|
||||
|
||||
account.name = args['name']
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(args['password'], salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
account.interface_language = args['interface_language']
|
||||
account.timezone = args['timezone']
|
||||
account.interface_theme = 'light'
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.initialized_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
api.add_resource(ActivateCheckApi, '/activate/check')
|
||||
api.add_resource(ActivateApi, '/activate')
|
||||
101
api/controllers/console/auth/data_source_oauth.py
Normal file
101
api/controllers/console/auth/data_source_oauth.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import flask_login
|
||||
import requests
|
||||
from flask import request, redirect, current_app, session
|
||||
from flask_login import current_user, login_required
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
from controllers.console import api
|
||||
from ..setup import setup_required
|
||||
from ..wraps import account_initialization_required
|
||||
|
||||
|
||||
def get_oauth_providers():
|
||||
with current_app.app_context():
|
||||
notion_oauth = NotionOAuth(client_id=current_app.config.get('NOTION_CLIENT_ID'),
|
||||
client_secret=current_app.config.get(
|
||||
'NOTION_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/data-source/callback/notion')
|
||||
|
||||
OAUTH_PROVIDERS = {
|
||||
'notion': notion_oauth
|
||||
}
|
||||
return OAUTH_PROVIDERS
|
||||
|
||||
|
||||
class OAuthDataSource(Resource):
|
||||
def get(self, provider: str):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
print(vars(oauth_provider))
|
||||
if not oauth_provider:
|
||||
return {'error': 'Invalid provider'}, 400
|
||||
if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
|
||||
internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
|
||||
oauth_provider.save_internal_access_token(internal_secret)
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
|
||||
else:
|
||||
auth_url = oauth_provider.get_authorization_url()
|
||||
return redirect(auth_url)
|
||||
|
||||
|
||||
|
||||
|
||||
class OAuthDataSourceCallback(Resource):
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
if not oauth_provider:
|
||||
return {'error': 'Invalid provider'}, 400
|
||||
if 'code' in request.args:
|
||||
code = request.args.get('code')
|
||||
try:
|
||||
oauth_provider.get_access_token(code)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logging.exception(
|
||||
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
||||
return {'error': 'OAuth data source process failed'}, 400
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
|
||||
elif 'error' in request.args:
|
||||
error = request.args.get('error')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source={error}')
|
||||
else:
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=access_denied')
|
||||
|
||||
|
||||
class OAuthDataSourceSync(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, binding_id):
|
||||
provider = str(provider)
|
||||
binding_id = str(binding_id)
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
if not oauth_provider:
|
||||
return {'error': 'Invalid provider'}, 400
|
||||
try:
|
||||
oauth_provider.sync_data_source(binding_id)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logging.exception(
|
||||
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
||||
return {'error': 'OAuth data source process failed'}, 400
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
|
||||
api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
|
||||
api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')
|
||||
@@ -20,13 +20,13 @@ def get_oauth_providers():
|
||||
client_secret=current_app.config.get(
|
||||
'GITHUB_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_URL') + '/console/api/oauth/authorize/github')
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/authorize/github')
|
||||
|
||||
google_oauth = GoogleOAuth(client_id=current_app.config.get('GOOGLE_CLIENT_ID'),
|
||||
client_secret=current_app.config.get(
|
||||
'GOOGLE_CLIENT_SECRET'),
|
||||
redirect_uri=current_app.config.get(
|
||||
'CONSOLE_URL') + '/console/api/oauth/authorize/google')
|
||||
'CONSOLE_API_URL') + '/console/api/oauth/authorize/google')
|
||||
|
||||
OAUTH_PROVIDERS = {
|
||||
'github': github_oauth,
|
||||
@@ -80,7 +80,7 @@ class OAuthCallback(Resource):
|
||||
flask_login.login_user(account, remember=True)
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_login=success')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success')
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
||||
|
||||
304
api/controllers/console/datasets/data_source.py
Normal file
304
api/controllers/console/datasets/data_source.py
Normal file
@@ -0,0 +1,304 @@
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, marshal_with, fields, reqparse, marshal
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.data_loader.loader.notion import NotionLoader
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from models.dataset import Document
|
||||
from models.source import DataSourceBinding
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
cache = TTLCache(maxsize=None, ttl=30)
|
||||
|
||||
FILE_SIZE_LIMIT = 15 * 1024 * 1024 # 15MB
|
||||
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm']
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
class DataSourceApi(Resource):
|
||||
integrate_icon_fields = {
|
||||
'type': fields.String,
|
||||
'url': fields.String,
|
||||
'emoji': fields.String
|
||||
}
|
||||
integrate_page_fields = {
|
||||
'page_name': fields.String,
|
||||
'page_id': fields.String,
|
||||
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
||||
'parent_id': fields.String,
|
||||
'type': fields.String
|
||||
}
|
||||
integrate_workspace_fields = {
|
||||
'workspace_name': fields.String,
|
||||
'workspace_id': fields.String,
|
||||
'workspace_icon': fields.String,
|
||||
'pages': fields.List(fields.Nested(integrate_page_fields)),
|
||||
'total': fields.Integer
|
||||
}
|
||||
integrate_fields = {
|
||||
'id': fields.String,
|
||||
'provider': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'is_bound': fields.Boolean,
|
||||
'disabled': fields.Boolean,
|
||||
'link': fields.String,
|
||||
'source_info': fields.Nested(integrate_workspace_fields)
|
||||
}
|
||||
integrate_list_fields = {
|
||||
'data': fields.List(fields.Nested(integrate_fields)),
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_list_fields)
|
||||
def get(self):
|
||||
# get workspace data source integrates
|
||||
data_source_integrates = db.session.query(DataSourceBinding).filter(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.disabled == False
|
||||
).all()
|
||||
|
||||
base_url = request.url_root.rstrip('/')
|
||||
data_source_oauth_base_path = "/console/api/oauth/data-source"
|
||||
providers = ["notion"]
|
||||
|
||||
integrate_data = []
|
||||
for provider in providers:
|
||||
# existing_integrate = next((ai for ai in data_source_integrates if ai.provider == provider), None)
|
||||
existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates)
|
||||
if existing_integrates:
|
||||
for existing_integrate in list(existing_integrates):
|
||||
integrate_data.append({
|
||||
'id': existing_integrate.id,
|
||||
'provider': provider,
|
||||
'created_at': existing_integrate.created_at,
|
||||
'is_bound': True,
|
||||
'disabled': existing_integrate.disabled,
|
||||
'source_info': existing_integrate.source_info,
|
||||
'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
|
||||
})
|
||||
else:
|
||||
integrate_data.append({
|
||||
'id': None,
|
||||
'provider': provider,
|
||||
'created_at': None,
|
||||
'source_info': None,
|
||||
'is_bound': False,
|
||||
'disabled': None,
|
||||
'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
|
||||
})
|
||||
return {'data': integrate_data}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, binding_id, action):
|
||||
binding_id = str(binding_id)
|
||||
action = str(action)
|
||||
data_source_binding = DataSourceBinding.query.filter_by(
|
||||
id=binding_id
|
||||
).first()
|
||||
if data_source_binding is None:
|
||||
raise NotFound('Data source binding not found.')
|
||||
# enable binding
|
||||
if action == 'enable':
|
||||
if data_source_binding.disabled:
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = datetime.datetime.utcnow()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError('Data source is not disabled.')
|
||||
# disable binding
|
||||
if action == 'disable':
|
||||
if not data_source_binding.disabled:
|
||||
data_source_binding.disabled = True
|
||||
data_source_binding.updated_at = datetime.datetime.utcnow()
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError('Data source is disabled.')
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class DataSourceNotionListApi(Resource):
|
||||
integrate_icon_fields = {
|
||||
'type': fields.String,
|
||||
'url': fields.String,
|
||||
'emoji': fields.String
|
||||
}
|
||||
integrate_page_fields = {
|
||||
'page_name': fields.String,
|
||||
'page_id': fields.String,
|
||||
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
||||
'is_bound': fields.Boolean,
|
||||
'parent_id': fields.String,
|
||||
'type': fields.String
|
||||
}
|
||||
integrate_workspace_fields = {
|
||||
'workspace_name': fields.String,
|
||||
'workspace_id': fields.String,
|
||||
'workspace_icon': fields.String,
|
||||
'pages': fields.List(fields.Nested(integrate_page_fields))
|
||||
}
|
||||
integrate_notion_info_list_fields = {
|
||||
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_notion_info_list_fields)
|
||||
def get(self):
|
||||
dataset_id = request.args.get('dataset_id', default=None, type=str)
|
||||
exist_page_ids = []
|
||||
# import notion in the exist dataset
|
||||
if dataset_id:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
if dataset.data_source_type != 'notion_import':
|
||||
raise ValueError('Dataset is not notion type.')
|
||||
documents = Document.query.filter_by(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type='notion_import',
|
||||
enabled=True
|
||||
).all()
|
||||
if documents:
|
||||
for document in documents:
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info['notion_page_id'])
|
||||
# get all authorized pages
|
||||
data_source_bindings = DataSourceBinding.query.filter_by(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider='notion',
|
||||
disabled=False
|
||||
).all()
|
||||
if not data_source_bindings:
|
||||
return {
|
||||
'notion_info': []
|
||||
}, 200
|
||||
pre_import_info_list = []
|
||||
for data_source_binding in data_source_bindings:
|
||||
source_info = data_source_binding.source_info
|
||||
pages = source_info['pages']
|
||||
# Filter out already bound pages
|
||||
for page in pages:
|
||||
if page['page_id'] in exist_page_ids:
|
||||
page['is_bound'] = True
|
||||
else:
|
||||
page['is_bound'] = False
|
||||
pre_import_info = {
|
||||
'workspace_name': source_info['workspace_name'],
|
||||
'workspace_icon': source_info['workspace_icon'],
|
||||
'workspace_id': source_info['workspace_id'],
|
||||
'pages': pages,
|
||||
}
|
||||
pre_import_info_list.append(pre_import_info)
|
||||
return {
|
||||
'notion_info': pre_import_info_list
|
||||
}, 200
|
||||
|
||||
|
||||
class DataSourceNotionApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
workspace_id = str(workspace_id)
|
||||
page_id = str(page_id)
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
raise NotFound('Data source binding not found.')
|
||||
|
||||
loader = NotionLoader(
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type
|
||||
)
|
||||
|
||||
text_docs = loader.load()
|
||||
return {
|
||||
'content': "\n".join([doc.page_content for doc in text_docs])
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
|
||||
return response, 200
|
||||
|
||||
|
||||
class DataSourceNotionDatasetSyncApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
documents = DocumentService.get_document_by_dataset_id(dataset_id_str)
|
||||
for document in documents:
|
||||
document_indexing_sync_task.delay(dataset_id_str, document.id)
|
||||
return 200
|
||||
|
||||
|
||||
class DataSourceNotionDocumentSyncApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
if document is None:
|
||||
raise NotFound("Document not found.")
|
||||
document_indexing_sync_task.delay(dataset_id_str, document_id_str)
|
||||
return 200
|
||||
|
||||
|
||||
api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
|
||||
api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
|
||||
api.add_resource(DataSourceNotionApi,
|
||||
'/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview',
|
||||
'/datasets/notion-indexing-estimate')
|
||||
api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets/<uuid:dataset_id>/notion/sync')
|
||||
api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync')
|
||||
@@ -3,7 +3,6 @@ from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
@@ -12,8 +11,9 @@ from controllers.console.wraps import account_initialization_required
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment, Document
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
|
||||
dataset_detail_fields = {
|
||||
'id': fields.String,
|
||||
@@ -50,8 +50,8 @@ def _validate_name(name):
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 200:
|
||||
raise ValueError('Description cannot exceed 200 characters.')
|
||||
if len(description) > 400:
|
||||
raise ValueError('Description cannot exceed 400 characters.')
|
||||
return description
|
||||
|
||||
|
||||
@@ -217,17 +217,32 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
segment_rule = request.get_json()
|
||||
file_detail = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == current_user.current_tenant_id,
|
||||
UploadFile.id == segment_rule["file_id"]
|
||||
).first()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
if args['info_list']['data_source_type'] == 'upload_file':
|
||||
file_ids = args['info_list']['file_info_list']['file_ids']
|
||||
file_details = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == current_user.current_tenant_id,
|
||||
UploadFile.id.in_(file_ids)
|
||||
).all()
|
||||
|
||||
if file_detail is None:
|
||||
raise NotFound("File not found.")
|
||||
if file_details is None:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.indexing_estimate(file_detail, segment_rule['process_rule'])
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form'])
|
||||
elif args['info_list']['data_source_type'] == 'notion_import':
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
|
||||
args['process_rule'], args['doc_form'])
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
return response, 200
|
||||
|
||||
|
||||
@@ -274,8 +289,54 @@ class DatasetRelatedAppListApi(Resource):
|
||||
}, 200
|
||||
|
||||
|
||||
class DatasetIndexingStatusApi(Resource):
|
||||
document_status_fields = {
|
||||
'id': fields.String,
|
||||
'indexing_status': fields.String,
|
||||
'processing_started_at': TimestampField,
|
||||
'parsing_completed_at': TimestampField,
|
||||
'cleaning_completed_at': TimestampField,
|
||||
'splitting_completed_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'paused_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer,
|
||||
}
|
||||
|
||||
document_status_fields_list = {
|
||||
'data': fields.List(fields.Nested(document_status_fields))
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
documents = db.session.query(Document).filter(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.tenant_id == current_user.current_tenant_id
|
||||
).all()
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
documents_status.append(marshal(document, self.document_status_fields))
|
||||
data = {
|
||||
'data': documents_status
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, '/datasets')
|
||||
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
|
||||
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
|
||||
api.add_resource(DatasetIndexingEstimateApi, '/datasets/file-indexing-estimate')
|
||||
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
|
||||
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
|
||||
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
@@ -59,6 +60,31 @@ document_fields = {
|
||||
'display_status': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'hit_count': fields.Integer,
|
||||
'doc_form': fields.String,
|
||||
}
|
||||
|
||||
document_with_segments_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'data_source_type': fields.String,
|
||||
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
|
||||
'dataset_process_rule_id': fields.String,
|
||||
'name': fields.String,
|
||||
'created_from': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'tokens': fields.Integer,
|
||||
'indexing_status': fields.String,
|
||||
'error': fields.String,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'archived': fields.Boolean,
|
||||
'display_status': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'hit_count': fields.Integer,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer
|
||||
}
|
||||
|
||||
|
||||
@@ -83,6 +109,23 @@ class DocumentResource(Resource):
|
||||
|
||||
return document
|
||||
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
documents = DocumentService.get_batch_documents(dataset_id, batch)
|
||||
|
||||
if not documents:
|
||||
raise NotFound('Documents not found.')
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
class GetProcessRuleApi(Resource):
|
||||
@setup_required
|
||||
@@ -132,9 +175,9 @@ class DatasetDocumentListApi(Resource):
|
||||
dataset_id = str(dataset_id)
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
search = request.args.get('search', default=None, type=str)
|
||||
search = request.args.get('keyword', default=None, type=str)
|
||||
sort = request.args.get('sort', default='-created_at', type=str)
|
||||
|
||||
fetch = request.args.get('fetch', default=False, type=bool)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
@@ -173,9 +216,20 @@ class DatasetDocumentListApi(Resource):
|
||||
paginated_documents = query.paginate(
|
||||
page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
documents = paginated_documents.items
|
||||
|
||||
if fetch:
|
||||
for document in documents:
|
||||
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
data = marshal(documents, document_with_segments_fields)
|
||||
else:
|
||||
data = marshal(documents, document_fields)
|
||||
response = {
|
||||
'data': marshal(documents, document_fields),
|
||||
'data': data,
|
||||
'has_more': len(documents) == limit,
|
||||
'limit': limit,
|
||||
'total': paginated_documents.total,
|
||||
@@ -184,10 +238,15 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
return response
|
||||
|
||||
documents_and_batch_fields = {
|
||||
'documents': fields.List(fields.Nested(document_fields)),
|
||||
'batch': fields.String
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(document_fields)
|
||||
@marshal_with(documents_and_batch_fields)
|
||||
def post(self, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
@@ -208,9 +267,11 @@ class DatasetDocumentListApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
|
||||
location='json')
|
||||
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('data_source', type=dict, required=False, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=False, location='json')
|
||||
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
|
||||
parser.add_argument('original_document_id', type=str, required=False, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if not dataset.indexing_technique and not args['indexing_technique']:
|
||||
@@ -220,21 +281,25 @@ class DatasetDocumentListApi(Resource):
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
try:
|
||||
document = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
|
||||
return document
|
||||
return {
|
||||
'documents': documents,
|
||||
'batch': batch
|
||||
}
|
||||
|
||||
|
||||
class DatasetInitApi(Resource):
|
||||
dataset_and_document_fields = {
|
||||
'dataset': fields.Nested(dataset_fields),
|
||||
'document': fields.Nested(document_fields)
|
||||
'documents': fields.List(fields.Nested(document_fields)),
|
||||
'batch': fields.String
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@@ -251,19 +316,20 @@ class DatasetInitApi(Resource):
|
||||
nullable=False, location='json')
|
||||
parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
try:
|
||||
dataset, document = DocumentService.save_document_without_dataset_id(
|
||||
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
document_data=args,
|
||||
account=current_user
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -271,7 +337,8 @@ class DatasetInitApi(Resource):
|
||||
|
||||
response = {
|
||||
'dataset': dataset,
|
||||
'document': document
|
||||
'documents': documents,
|
||||
'batch': batch
|
||||
}
|
||||
|
||||
return response
|
||||
@@ -316,11 +383,124 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
raise NotFound('File not found.')
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.indexing_estimate(file, data_process_rule_dict)
|
||||
|
||||
response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, batch):
|
||||
dataset_id = str(dataset_id)
|
||||
batch = str(batch)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
documents = self.get_batch_documents(dataset_id, batch)
|
||||
response = {
|
||||
"tokens": 0,
|
||||
"total_price": 0,
|
||||
"currency": "USD",
|
||||
"total_segments": 0,
|
||||
"preview": []
|
||||
}
|
||||
if not documents:
|
||||
return response
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
info_list = []
|
||||
for document in documents:
|
||||
if document.indexing_status in ['completed', 'error']:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
data_source_info = document.data_source_info_dict
|
||||
# format document files info
|
||||
if data_source_info and 'upload_file_id' in data_source_info:
|
||||
file_id = data_source_info['upload_file_id']
|
||||
info_list.append(file_id)
|
||||
# format document notion info
|
||||
elif data_source_info and 'notion_workspace_id' in data_source_info and 'notion_page_id' in data_source_info:
|
||||
pages = []
|
||||
page = {
|
||||
'page_id': data_source_info['notion_page_id'],
|
||||
'type': data_source_info['type']
|
||||
}
|
||||
pages.append(page)
|
||||
notion_info = {
|
||||
'workspace_id': data_source_info['notion_workspace_id'],
|
||||
'pages': pages
|
||||
}
|
||||
info_list.append(notion_info)
|
||||
|
||||
if dataset.data_source_type == 'upload_file':
|
||||
file_details = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == current_user.current_tenant_id,
|
||||
UploadFile.id in info_list
|
||||
).all()
|
||||
|
||||
if file_details is None:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.file_indexing_estimate(file_details, data_process_rule_dict)
|
||||
elif dataset.data_source_type:
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.notion_indexing_estimate(info_list,
|
||||
data_process_rule_dict)
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
return response
|
||||
|
||||
|
||||
class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
document_status_fields = {
|
||||
'id': fields.String,
|
||||
'indexing_status': fields.String,
|
||||
'processing_started_at': TimestampField,
|
||||
'parsing_completed_at': TimestampField,
|
||||
'cleaning_completed_at': TimestampField,
|
||||
'splitting_completed_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'paused_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer,
|
||||
}
|
||||
|
||||
document_status_fields_list = {
|
||||
'data': fields.List(fields.Nested(document_status_fields))
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, batch):
|
||||
dataset_id = str(dataset_id)
|
||||
batch = str(batch)
|
||||
documents = self.get_batch_documents(dataset_id, batch)
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
if document.is_paused:
|
||||
document.indexing_status = 'paused'
|
||||
documents_status.append(marshal(document, self.document_status_fields))
|
||||
data = {
|
||||
'data': documents_status
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
class DocumentIndexingStatusApi(DocumentResource):
|
||||
document_status_fields = {
|
||||
'id': fields.String,
|
||||
@@ -347,10 +527,12 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
|
||||
completed_segments = DocumentSegment.query \
|
||||
.filter(DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id)) \
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != 're_segment') \
|
||||
.count()
|
||||
total_segments = DocumentSegment.query \
|
||||
.filter_by(document_id=str(document_id)) \
|
||||
.filter(DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != 're_segment') \
|
||||
.count()
|
||||
|
||||
document.completed_segments = completed_segments
|
||||
@@ -405,9 +587,10 @@ class DocumentDetailApi(DocumentResource):
|
||||
'disabled_by': document.disabled_by,
|
||||
'archived': document.archived,
|
||||
'segment_count': document.segment_count,
|
||||
'average_segment_length': document.average_segment_length,
|
||||
'average_segment_length': document.average_segment_length,
|
||||
'hit_count': document.hit_count,
|
||||
'display_status': document.display_status
|
||||
'display_status': document.display_status,
|
||||
'doc_form': document.doc_form
|
||||
}
|
||||
else:
|
||||
process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
@@ -425,7 +608,7 @@ class DocumentDetailApi(DocumentResource):
|
||||
'created_at': document.created_at.timestamp(),
|
||||
'tokens': document.tokens,
|
||||
'indexing_status': document.indexing_status,
|
||||
'completed_at': int(document.completed_at.timestamp())if document.completed_at else None,
|
||||
'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None,
|
||||
'indexing_latency': document.indexing_latency,
|
||||
'error': document.error,
|
||||
@@ -438,7 +621,8 @@ class DocumentDetailApi(DocumentResource):
|
||||
'segment_count': document.segment_count,
|
||||
'average_segment_length': document.average_segment_length,
|
||||
'hit_count': document.hit_count,
|
||||
'display_status': document.display_status
|
||||
'display_status': document.display_status,
|
||||
'doc_form': document.doc_form
|
||||
}
|
||||
|
||||
return response, 200
|
||||
@@ -576,6 +760,8 @@ class DocumentStatusApi(DocumentResource):
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
elif action == "disable":
|
||||
if not document.completed_at or document.indexing_status != 'completed':
|
||||
raise InvalidActionError('Document is not completed.')
|
||||
if not document.enabled:
|
||||
raise InvalidActionError('Document already disabled.')
|
||||
|
||||
@@ -675,6 +861,10 @@ api.add_resource(DatasetInitApi,
|
||||
'/datasets/init')
|
||||
api.add_resource(DocumentIndexingEstimateApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate')
|
||||
api.add_resource(DocumentBatchIndexingEstimateApi,
|
||||
'/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate')
|
||||
api.add_resource(DocumentBatchIndexingStatusApi,
|
||||
'/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status')
|
||||
api.add_resource(DocumentIndexingStatusApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status')
|
||||
api.add_resource(DocumentDetailApi,
|
||||
|
||||
@@ -15,8 +15,8 @@ from extensions.ext_redis import redis_client
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
from libs.helper import TimestampField
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from tasks.add_segment_to_index_task import add_segment_to_index_task
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
|
||||
from tasks.remove_segment_from_index_task import remove_segment_from_index_task
|
||||
|
||||
segment_fields = {
|
||||
@@ -24,6 +24,7 @@ segment_fields = {
|
||||
'position': fields.Integer,
|
||||
'document_id': fields.String,
|
||||
'content': fields.String,
|
||||
'answer': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'tokens': fields.Integer,
|
||||
'keywords': fields.List(fields.String),
|
||||
@@ -78,12 +79,14 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
parser.add_argument('hit_count_gte', type=int,
|
||||
default=None, location='args')
|
||||
parser.add_argument('enabled', type=str, default='all', location='args')
|
||||
parser.add_argument('keyword', type=str, default=None, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
last_id = args['last_id']
|
||||
limit = min(args['limit'], 100)
|
||||
status_list = args['status']
|
||||
hit_count_gte = args['hit_count_gte']
|
||||
keyword = args['keyword']
|
||||
|
||||
query = DocumentSegment.query.filter(
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
@@ -104,6 +107,9 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
if hit_count_gte is not None:
|
||||
query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
|
||||
|
||||
if keyword:
|
||||
query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
|
||||
|
||||
if args['enabled'].lower() != 'all':
|
||||
if args['enabled'].lower() == 'true':
|
||||
query = query.filter(DocumentSegment.enabled == True)
|
||||
@@ -120,6 +126,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
|
||||
return {
|
||||
'data': marshal(segments, segment_fields),
|
||||
'doc_form': document.doc_form,
|
||||
'has_more': has_more,
|
||||
'limit': limit,
|
||||
'total': total
|
||||
@@ -175,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
# Set cache to prevent indexing the same segment multiple times
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
|
||||
add_segment_to_index_task.delay(segment.id)
|
||||
enable_segment_to_index_task.delay(segment.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
elif action == "disable":
|
||||
@@ -197,7 +204,91 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
raise InvalidActionError()
|
||||
|
||||
|
||||
class DatasetDocumentSegmentAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id, document_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.create_segment(args, document)
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
|
||||
class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
).first()
|
||||
if not segment:
|
||||
raise NotFound('Segment not found.')
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(args, segment, document)
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetDocumentSegmentListApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
|
||||
api.add_resource(DatasetDocumentSegmentApi,
|
||||
'/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
|
||||
api.add_resource(DatasetDocumentSegmentAddApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
|
||||
api.add_resource(DatasetDocumentSegmentUpdateApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import datetime
|
||||
import hashlib
|
||||
import tempfile
|
||||
import chardet
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
@@ -16,8 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles
|
||||
UnsupportedFileTypeError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.index.readers.html_parser import HTMLParser
|
||||
from core.index.readers.pdf_parser import PDFParser
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from extensions.ext_storage import storage
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
@@ -26,7 +26,7 @@ from models.model import UploadFile
|
||||
cache = TTLCache(maxsize=None, ttl=30)
|
||||
|
||||
FILE_SIZE_LIMIT = 15 * 1024 * 1024 # 15MB
|
||||
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm']
|
||||
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx']
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
@@ -121,24 +121,7 @@ class FilePreviewApi(Resource):
|
||||
if extension not in ALLOWED_EXTENSIONS:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
storage.download(upload_file.key, filepath)
|
||||
|
||||
if extension == 'pdf':
|
||||
parser = PDFParser({'upload_file': upload_file})
|
||||
text = parser.parse_file(Path(filepath))
|
||||
elif extension in ['html', 'htm']:
|
||||
# Use BeautifulSoup to extract text
|
||||
parser = HTMLParser()
|
||||
text = parser.parse_file(Path(filepath))
|
||||
else:
|
||||
# ['txt', 'markdown', 'md']
|
||||
with open(filepath, "rb") as fp:
|
||||
data = fp.read()
|
||||
text = data.decode(encoding='utf-8').strip() if data else ''
|
||||
|
||||
text = FileExtractor.load(upload_file, return_text=True)
|
||||
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
|
||||
return {'content': text}
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ segment_fields = {
|
||||
'position': fields.Integer,
|
||||
'document_id': fields.String,
|
||||
'content': fields.String,
|
||||
'answer': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'tokens': fields.Integer,
|
||||
'keywords': fields.List(fields.String),
|
||||
@@ -95,8 +96,8 @@ class HitTestingApi(Resource):
|
||||
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -18,3 +18,9 @@ class AccountNotLinkTenantError(BaseHTTPException):
|
||||
error_code = 'account_not_link_tenant'
|
||||
description = "Account not link tenant."
|
||||
code = 403
|
||||
|
||||
|
||||
class AlreadyActivateError(BaseHTTPException):
|
||||
error_code = 'already_activate'
|
||||
description = "Auth Token is invalid or account already activated, please check again."
|
||||
code = 403
|
||||
|
||||
66
api/controllers/console/explore/audio.py
Normal file
66
api/controllers/console/explore/audio.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError, \
|
||||
NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
from models.model import AppModelConfig
|
||||
|
||||
|
||||
class ChatAudioApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.speech_to_text_dict['enabled']:
|
||||
raise AppUnavailableError()
|
||||
|
||||
file = request.files['file']
|
||||
|
||||
try:
|
||||
response = AudioService.transcript(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file,
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
|
||||
@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -65,7 +65,10 @@ class ConversationApi(InstalledAppResource):
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
|
||||
except MoreLikeThisDisabledError:
|
||||
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -4,6 +4,10 @@ from flask_restful import marshal_with, fields
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.provider import ProviderName
|
||||
from models.model import InstalledApp
|
||||
|
||||
|
||||
class AppParameterApi(InstalledAppResource):
|
||||
"""Resource for app variables."""
|
||||
@@ -21,20 +25,23 @@ class AppParameterApi(InstalledAppResource):
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, installed_app):
|
||||
def get(self, installed_app: InstalledApp):
|
||||
"""Retrieve app parameters."""
|
||||
app_model = installed_app.app
|
||||
app_model_config = app_model.app_model_config
|
||||
provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1')
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
||||
66
api/controllers/console/universal_chat/audio.py
Normal file
66
api/controllers/console/universal_chat/audio.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError, \
|
||||
NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
from models.model import AppModelConfig
|
||||
|
||||
|
||||
class UniversalChatAudioApi(UniversalChatResource):
|
||||
def post(self, universal_app):
|
||||
app_model = universal_app
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.speech_to_text_dict['enabled']:
|
||||
raise AppUnavailableError()
|
||||
|
||||
file = request.files['file']
|
||||
|
||||
try:
|
||||
response = AudioService.transcript(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file,
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(UniversalChatAudioApi, '/universal-chat/audio-to-text')
|
||||
142
api/controllers/console/universal_chat/chat.py
Normal file
142
api/controllers/console/universal_chat/chat.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
from flask_restful import reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.constant import llm_constant
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
|
||||
|
||||
class UniversalChatApi(UniversalChatResource):
|
||||
def post(self, universal_app):
|
||||
app_model = universal_app
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('model', type=str, required=True, location='json')
|
||||
parser.add_argument('tools', type=list, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_model_config = app_model.app_model_config
|
||||
|
||||
# update app model config
|
||||
args['model_config'] = app_model_config.to_dict()
|
||||
args['model_config']['model']['name'] = args['model']
|
||||
|
||||
if not llm_constant.models[args['model']]:
|
||||
raise ValueError("Model not exists.")
|
||||
|
||||
args['model_config']['model']['provider'] = llm_constant.models[args['model']]
|
||||
args['model_config']['agent_mode']['tools'] = args['tools']
|
||||
|
||||
if not args['model_config']['agent_mode']['tools']:
|
||||
args['model_config']['agent_mode']['tools'] = [
|
||||
{
|
||||
"current_datetime": {
|
||||
"enabled": True
|
||||
}
|
||||
}
|
||||
]
|
||||
else:
|
||||
args['model_config']['agent_mode']['tools'].append({
|
||||
"current_datetime": {
|
||||
"enabled": True
|
||||
}
|
||||
})
|
||||
|
||||
args['inputs'] = {}
|
||||
|
||||
del args['model']
|
||||
del args['tools']
|
||||
|
||||
try:
|
||||
response = CompletionService.completion(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
from_source='console',
|
||||
streaming=True,
|
||||
is_model_config_override=True,
|
||||
)
|
||||
|
||||
return compact_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class UniversalChatStopApi(UniversalChatResource):
|
||||
def post(self, universal_app, task_id):
|
||||
PubHandler.stop(current_user, task_id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
try:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||
except ValueError as e:
|
||||
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
|
||||
api.add_resource(UniversalChatApi, '/universal-chat/messages')
|
||||
api.add_resource(UniversalChatStopApi, '/universal-chat/messages/<string:task_id>/stop')
|
||||
118
api/controllers/console/universal_chat/conversation.py
Normal file
118
api/controllers/console/universal_chat/conversation.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import fields, reqparse, marshal_with
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'status': fields.String,
|
||||
'introduction': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'model_config': fields.Raw,
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(conversation_fields))
|
||||
}
|
||||
|
||||
|
||||
class UniversalChatConversationListApi(UniversalChatResource):
|
||||
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, universal_app):
|
||||
app_model = universal_app
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
pinned = None
|
||||
if 'pinned' in args and args['pinned'] is not None:
|
||||
pinned = True if args['pinned'] == 'true' else False
|
||||
|
||||
try:
|
||||
return WebConversationService.pagination_by_last_id(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
last_id=args['last_id'],
|
||||
limit=args['limit'],
|
||||
pinned=pinned
|
||||
)
|
||||
except LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
||||
class UniversalChatConversationApi(UniversalChatResource):
|
||||
def delete(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class UniversalChatConversationRenameApi(UniversalChatResource):
|
||||
|
||||
@marshal_with(conversation_fields)
|
||||
def post(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
class UniversalChatConversationPinApi(UniversalChatResource):
|
||||
|
||||
def patch(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
|
||||
try:
|
||||
WebConversationService.pin(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class UniversalChatConversationUnPinApi(UniversalChatResource):
|
||||
def patch(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(UniversalChatConversationRenameApi, '/universal-chat/conversations/<uuid:c_id>/name')
|
||||
api.add_resource(UniversalChatConversationListApi, '/universal-chat/conversations')
|
||||
api.add_resource(UniversalChatConversationApi, '/universal-chat/conversations/<uuid:c_id>')
|
||||
api.add_resource(UniversalChatConversationPinApi, '/universal-chat/conversations/<uuid:c_id>/pin')
|
||||
api.add_resource(UniversalChatConversationUnPinApi, '/universal-chat/conversations/<uuid:c_id>/unpin')
|
||||
127
api/controllers/console/universal_chat/message.py
Normal file
127
api/controllers/console/universal_chat/message.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import reqparse, fields, marshal_with
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound, InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
|
||||
|
||||
class UniversalChatMessageListApi(UniversalChatResource):
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
}
|
||||
|
||||
agent_thought_fields = {
|
||||
'id': fields.String,
|
||||
'chain_id': fields.String,
|
||||
'message_id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'thought': fields.String,
|
||||
'tool': fields.String,
|
||||
'tool_input': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'created_at': TimestampField,
|
||||
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
|
||||
}
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(message_fields))
|
||||
}
|
||||
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, universal_app):
|
||||
app_model = universal_app
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
||||
parser.add_argument('first_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(app_model, current_user,
|
||||
args['conversation_id'], args['first_id'], args['limit'])
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.message.FirstMessageNotExistsError:
|
||||
raise NotFound("First Message Not Exists.")
|
||||
|
||||
|
||||
class UniversalChatMessageFeedbackApi(UniversalChatResource):
|
||||
def post(self, universal_app, message_id):
|
||||
app_model = universal_app
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
|
||||
except services.errors.message.MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
|
||||
def get(self, universal_app, message_id):
|
||||
app_model = universal_app
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
message_id=message_id
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {'data': questions}
|
||||
|
||||
|
||||
api.add_resource(UniversalChatMessageListApi, '/universal-chat/messages')
|
||||
api.add_resource(UniversalChatMessageFeedbackApi, '/universal-chat/messages/<uuid:message_id>/feedbacks')
|
||||
api.add_resource(UniversalChatMessageSuggestedQuestionApi, '/universal-chat/messages/<uuid:message_id>/suggested-questions')
|
||||
36
api/controllers/console/universal_chat/parameter.py
Normal file
36
api/controllers/console/universal_chat/parameter.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_restful import marshal_with, fields
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.provider import ProviderName
|
||||
from models.model import App
|
||||
|
||||
|
||||
class UniversalChatParameterApi(UniversalChatResource):
|
||||
"""Resource for app variables."""
|
||||
parameters_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, universal_app: App):
|
||||
"""Retrieve app parameters."""
|
||||
app_model = universal_app
|
||||
app_model_config = app_model.app_model_config
|
||||
provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(UniversalChatParameterApi, '/universal-chat/parameters')
|
||||
84
api/controllers/console/universal_chat/wraps.py
Normal file
84
api/controllers/console/universal_chat/wraps.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import json
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
|
||||
|
||||
def universal_chat_app_required(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# get universal chat app
|
||||
universal_app = db.session.query(App).filter(
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.is_universal == True
|
||||
).first()
|
||||
|
||||
if universal_app is None:
|
||||
# create universal app if not exists
|
||||
universal_app = App(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
name='Universal Chat',
|
||||
mode='chat',
|
||||
is_universal=True,
|
||||
icon='',
|
||||
icon_background='',
|
||||
api_rpm=0,
|
||||
api_rph=0,
|
||||
enable_site=False,
|
||||
enable_api=False,
|
||||
status='normal'
|
||||
)
|
||||
|
||||
db.session.add(universal_app)
|
||||
db.session.flush()
|
||||
|
||||
app_model_config = AppModelConfig(
|
||||
provider="",
|
||||
model_id="",
|
||||
configs={},
|
||||
opening_statement='',
|
||||
suggested_questions=json.dumps([]),
|
||||
suggested_questions_after_answer=json.dumps({'enabled': True}),
|
||||
speech_to_text=json.dumps({'enabled': True}),
|
||||
more_like_this=None,
|
||||
sensitive_word_avoidance=None,
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-16k",
|
||||
"completion_params": {
|
||||
"max_tokens": 800,
|
||||
"temperature": 0.8,
|
||||
"top_p": 1,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0
|
||||
}
|
||||
}),
|
||||
user_input_form=json.dumps([]),
|
||||
pre_prompt='',
|
||||
agent_mode=json.dumps({"enabled": True, "strategy": "function_call", "tools": []}),
|
||||
)
|
||||
|
||||
app_model_config.app_id = universal_app.id
|
||||
db.session.add(app_model_config)
|
||||
db.session.flush()
|
||||
|
||||
universal_app.app_model_config_id = app_model_config.id
|
||||
db.session.commit()
|
||||
|
||||
return view(universal_app, *args, **kwargs)
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
class UniversalChatResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
method_decorators = [universal_chat_app_required, account_initialization_required, login_required, setup_required]
|
||||
@@ -32,8 +32,13 @@ class VersionApi(Resource):
|
||||
'current_version': args.get('current_version')
|
||||
})
|
||||
except Exception as error:
|
||||
logging.exception("Check update error.")
|
||||
raise InternalServerError()
|
||||
logging.warning("Check update version error: {}.".format(str(error)))
|
||||
return {
|
||||
'version': args.get('current_version'),
|
||||
'release_date': '',
|
||||
'release_notes': '',
|
||||
'can_auto_update': False
|
||||
}
|
||||
|
||||
content = json.loads(response.content)
|
||||
return {
|
||||
|
||||
@@ -6,22 +6,23 @@ from flask import current_app, request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.workspace.error import AccountAlreadyInitedError, InvalidInvitationCodeError, \
|
||||
RepeatPasswordNotMatchError
|
||||
RepeatPasswordNotMatchError, CurrentPasswordIncorrectError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import TimestampField, supported_language, timezone
|
||||
from extensions.ext_database import db
|
||||
from models.account import InvitationCode, AccountIntegrate
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'avatar': fields.String,
|
||||
'email': fields.String,
|
||||
'is_password_set': fields.Boolean,
|
||||
'interface_language': fields.String,
|
||||
'interface_theme': fields.String,
|
||||
'timezone': fields.String,
|
||||
@@ -194,8 +195,11 @@ class AccountPasswordApi(Resource):
|
||||
if args['new_password'] != args['repeat_new_password']:
|
||||
raise RepeatPasswordNotMatchError()
|
||||
|
||||
AccountService.update_account_password(
|
||||
current_user, args['password'], args['new_password'])
|
||||
try:
|
||||
AccountService.update_account_password(
|
||||
current_user, args['password'], args['new_password'])
|
||||
except ServiceCurrentPasswordIncorrectError:
|
||||
raise CurrentPasswordIncorrectError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@@ -7,6 +7,12 @@ class RepeatPasswordNotMatchError(BaseHTTPException):
|
||||
code = 400
|
||||
|
||||
|
||||
class CurrentPasswordIncorrectError(BaseHTTPException):
|
||||
error_code = 'current_password_incorrect'
|
||||
description = "Current password is incorrect."
|
||||
code = 400
|
||||
|
||||
|
||||
class ProviderRequestFailedError(BaseHTTPException):
|
||||
error_code = 'provider_request_failed'
|
||||
description = None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
|
||||
|
||||
@@ -60,7 +60,8 @@ class MemberInviteEmailApi(Resource):
|
||||
inviter = current_user
|
||||
|
||||
try:
|
||||
RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role, inviter=inviter)
|
||||
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
|
||||
inviter=inviter)
|
||||
account = db.session.query(Account, TenantAccountJoin.role).join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
).filter(Account.email == args['email']).first()
|
||||
@@ -78,7 +79,16 @@ class MemberInviteEmailApi(Resource):
|
||||
|
||||
# todo:413
|
||||
|
||||
return {'result': 'success', 'account': account}, 201
|
||||
return {
|
||||
'result': 'success',
|
||||
'account': account,
|
||||
'invite_url': '{}/activate?workspace_id={}&email={}&token={}'.format(
|
||||
current_app.config.get("CONSOLE_WEB_URL"),
|
||||
str(current_user.current_tenant_id),
|
||||
invitee_email,
|
||||
token
|
||||
)
|
||||
}, 201
|
||||
|
||||
|
||||
class MemberCancelInviteApi(Resource):
|
||||
@@ -88,7 +98,7 @@ class MemberCancelInviteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, member_id):
|
||||
member = Account.query.get(str(member_id))
|
||||
member = db.session.query(Account).filter(Account.id == str(member_id)).first()
|
||||
if not member:
|
||||
abort(404)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, abort
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@@ -34,7 +35,7 @@ class ProviderListApi(Resource):
|
||||
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
|
||||
"""
|
||||
|
||||
ProviderService.init_supported_provider(current_user.current_tenant, "cloud")
|
||||
ProviderService.init_supported_provider(current_user.current_tenant)
|
||||
providers = Provider.query.filter_by(tenant_id=tenant_id).all()
|
||||
|
||||
provider_list = [
|
||||
@@ -50,7 +51,8 @@ class ProviderListApi(Resource):
|
||||
'quota_used': p.quota_used
|
||||
} if p.provider_type == ProviderType.SYSTEM.value else {}),
|
||||
'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
|
||||
ProviderName(p.provider_name))
|
||||
ProviderName(p.provider_name), only_custom=True)
|
||||
if p.provider_type == ProviderType.CUSTOM.value else None
|
||||
}
|
||||
for p in providers
|
||||
]
|
||||
@@ -121,9 +123,10 @@ class ProviderTokenApi(Resource):
|
||||
is_valid=token_is_valid)
|
||||
db.session.add(provider_model)
|
||||
|
||||
if provider_model.is_valid:
|
||||
if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
|
||||
other_providers = db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant.id,
|
||||
Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
|
||||
Provider.provider_name != provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).all()
|
||||
@@ -133,7 +136,7 @@ class ProviderTokenApi(Resource):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
|
||||
if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
|
||||
ProviderName.HUGGINGFACEHUB.value]:
|
||||
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
|
||||
|
||||
@@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
# todo: remove this when the provider is supported
|
||||
if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
|
||||
if provider in [ProviderName.COHERE.value,
|
||||
ProviderName.HUGGINGFACEHUB.value]:
|
||||
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
|
||||
|
||||
@@ -203,7 +206,19 @@ class ProviderSystemApi(Resource):
|
||||
provider_model.is_valid = args['is_enabled']
|
||||
db.session.commit()
|
||||
elif not provider_model:
|
||||
ProviderService.create_system_provider(tenant, provider, args['is_enabled'])
|
||||
if provider == ProviderName.OPENAI.value:
|
||||
quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
|
||||
elif provider == ProviderName.ANTHROPIC.value:
|
||||
quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
|
||||
else:
|
||||
quota_limit = 0
|
||||
|
||||
ProviderService.create_system_provider(
|
||||
tenant,
|
||||
provider,
|
||||
quota_limit,
|
||||
args['is_enabled']
|
||||
)
|
||||
else:
|
||||
abort(403)
|
||||
|
||||
136
api/controllers/console/workspace/tool_providers.py
Normal file
136
api/controllers/console/workspace/tool_providers.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import json
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, abort, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.tool.provider.errors import ToolValidateFailedError
|
||||
from core.tool.provider.tool_provider_service import ToolProviderService
|
||||
from extensions.ext_database import db
|
||||
from models.tool import ToolProvider, ToolProviderName
|
||||
|
||||
|
||||
class ToolProviderListApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
tool_credential_dict = {}
|
||||
for tool_name in ToolProviderName:
|
||||
tool_credential_dict[tool_name.value] = {
|
||||
'tool_name': tool_name.value,
|
||||
'is_enabled': False,
|
||||
'credentials': None
|
||||
}
|
||||
|
||||
tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
for p in tool_providers:
|
||||
if p.is_enabled:
|
||||
tool_credential_dict[p.tool_name] = {
|
||||
'tool_name': p.tool_name,
|
||||
'is_enabled': p.is_enabled,
|
||||
'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True)
|
||||
}
|
||||
|
||||
return list(tool_credential_dict.values())
|
||||
|
||||
|
||||
class ToolProviderCredentialsApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
if provider not in [p.value for p in ToolProviderName]:
|
||||
abort(404)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden(f'User {current_user.id} is not authorized to update provider token, '
|
||||
f'current_role is {current_user.current_tenant.current_role}')
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
tool_provider_service = ToolProviderService(tenant_id, provider)
|
||||
|
||||
try:
|
||||
tool_provider_service.credentials_validate(args['credentials'])
|
||||
except ToolValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials']))
|
||||
|
||||
tenant = current_user.current_tenant
|
||||
|
||||
tool_provider_model = db.session.query(ToolProvider).filter(
|
||||
ToolProvider.tenant_id == tenant.id,
|
||||
ToolProvider.tool_name == provider,
|
||||
).first()
|
||||
|
||||
# Only allow updating token for CUSTOM provider type
|
||||
if tool_provider_model:
|
||||
tool_provider_model.encrypted_credentials = encrypted_credentials
|
||||
tool_provider_model.is_enabled = True
|
||||
else:
|
||||
tool_provider_model = ToolProvider(
|
||||
tenant_id=tenant.id,
|
||||
tool_name=provider,
|
||||
encrypted_credentials=encrypted_credentials,
|
||||
is_enabled=True
|
||||
)
|
||||
db.session.add(tool_provider_model)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}, 201
|
||||
|
||||
|
||||
class ToolProviderCredentialsValidateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
if provider not in [p.value for p in ToolProviderName]:
|
||||
abort(404)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
result = True
|
||||
error = None
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
tool_provider_service = ToolProviderService(tenant_id, provider)
|
||||
|
||||
try:
|
||||
tool_provider_service.credentials_validate(args['credentials'])
|
||||
except ToolValidateFailedError as ex:
|
||||
result = False
|
||||
error = str(ex)
|
||||
|
||||
response = {'result': 'success' if result else 'error'}
|
||||
|
||||
if not result:
|
||||
response['error'] = error
|
||||
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
|
||||
api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers/<provider>/credentials')
|
||||
api.add_resource(ToolProviderCredentialsValidateApi,
|
||||
'/workspaces/current/tool-providers/<provider>/credentials-validate')
|
||||
@@ -7,6 +7,6 @@ bp = Blueprint('service_api', __name__, url_prefix='/v1')
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from .app import completion, app, conversation, message
|
||||
from .app import completion, app, conversation, message, audio
|
||||
|
||||
from .dataset import document
|
||||
|
||||
@@ -4,6 +4,10 @@ from flask_restful import fields, marshal_with
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.provider import ProviderName
|
||||
from models.model import App
|
||||
|
||||
|
||||
class AppParameterApi(AppApiResource):
|
||||
"""Resource for app variables."""
|
||||
@@ -22,19 +26,22 @@ class AppParameterApi(AppApiResource):
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
app_model_config = app_model.app_model_config
|
||||
provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
||||
61
api/controllers/service_api/app/audio.py
Normal file
61
api/controllers/service_api/app/audio.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import AppUnavailableError, ProviderNotInitializeError, CompletionRequestError, ProviderQuotaExceededError, \
|
||||
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
|
||||
ProviderNotSupportSpeechToTextError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from models.model import App, AppModelConfig
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
|
||||
class AudioApi(AppApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.speech_to_text_dict['enabled']:
|
||||
raise AppUnavailableError()
|
||||
|
||||
file = request.files['file']
|
||||
|
||||
try:
|
||||
response = AudioService.transcript(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file,
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
api.add_resource(AudioApi, '/audio-to-text')
|
||||
@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import request
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
@@ -48,6 +49,24 @@ class ConversationApi(AppApiResource):
|
||||
except services.errors.conversation.LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
class ConversationDetailApi(AppApiResource):
|
||||
@marshal_with(conversation_fields)
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
user = request.get_json().get('user')
|
||||
|
||||
if end_user is None and user is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user)
|
||||
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
return {"result": "success"}
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
class ConversationRenameApi(AppApiResource):
|
||||
|
||||
@@ -74,3 +93,5 @@ class ConversationRenameApi(AppApiResource):
|
||||
|
||||
api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='conversation_name')
|
||||
api.add_resource(ConversationApi, '/conversations')
|
||||
api.add_resource(ConversationApi, '/conversations/<uuid:c_id>', endpoint='conversation')
|
||||
api.add_resource(ConversationDetailApi, '/conversations/<uuid:c_id>', endpoint='conversation_detail')
|
||||
|
||||
@@ -51,3 +51,27 @@ class CompletionRequestError(BaseHTTPException):
|
||||
description = "Completion request failed."
|
||||
code = 400
|
||||
|
||||
|
||||
class NoAudioUploadedError(BaseHTTPException):
|
||||
error_code = 'no_audio_uploaded'
|
||||
description = "Please upload your audio."
|
||||
code = 400
|
||||
|
||||
|
||||
class AudioTooLargeError(BaseHTTPException):
|
||||
error_code = 'audio_too_large'
|
||||
description = "Audio size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedAudioTypeError(BaseHTTPException):
|
||||
error_code = 'unsupported_audio_type'
|
||||
description = "Audio type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
error_code = 'provider_not_support_speech_to_text'
|
||||
description = "Provider not support speech to text."
|
||||
code = 400
|
||||
|
||||
|
||||
@@ -69,21 +69,25 @@ class DocumentListApi(DatasetApiResource):
|
||||
document_data = {
|
||||
'data_source': {
|
||||
'type': 'upload_file',
|
||||
'info': upload_file.id
|
||||
'info': [
|
||||
{
|
||||
'upload_file_id': upload_file.id
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
document = DocumentService.save_document_with_dataset_id(
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
document_data=document_data,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule,
|
||||
created_from='api'
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
if doc_type and doc_metadata:
|
||||
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
|
||||
|
||||
|
||||
@@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from . import completion, app, conversation, message, site, saved_message
|
||||
from . import completion, app, conversation, message, site, saved_message, audio, passport
|
||||
|
||||
@@ -4,6 +4,10 @@ from flask_restful import marshal_with, fields
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from models.provider import ProviderName
|
||||
from models.model import App
|
||||
|
||||
|
||||
class AppParameterApi(WebApiResource):
|
||||
"""Resource for app variables."""
|
||||
@@ -21,19 +25,22 @@ class AppParameterApi(WebApiResource):
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
app_model_config = app_model.app_model_config
|
||||
provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
}
|
||||
|
||||
63
api/controllers/web/audio.py
Normal file
63
api/controllers/web/audio.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.web import api
|
||||
from controllers.web.error import AppUnavailableError, ProviderNotInitializeError, CompletionRequestError, \
|
||||
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||
from models.model import App, AppModelConfig
|
||||
|
||||
|
||||
class AudioApi(WebApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.speech_to_text_dict['enabled']:
|
||||
raise AppUnavailableError()
|
||||
|
||||
file = request.files['file']
|
||||
|
||||
try:
|
||||
response = AudioService.transcript(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file,
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
api.add_resource(AudioApi, '/audio-to-text')
|
||||
@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
@@ -62,7 +62,10 @@ class ConversationApi(WebApiResource):
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
WebConversationService.unpin(app_model, conversation_id, end_user)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@@ -62,3 +62,27 @@ class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
|
||||
error_code = 'app_suggested_questions_after_answer_disabled'
|
||||
description = "The 'Suggested Questions After Answer' feature is disabled. Please refresh your page."
|
||||
code = 403
|
||||
|
||||
|
||||
class NoAudioUploadedError(BaseHTTPException):
|
||||
error_code = 'no_audio_uploaded'
|
||||
description = "Please upload your audio."
|
||||
code = 400
|
||||
|
||||
|
||||
class AudioTooLargeError(BaseHTTPException):
|
||||
error_code = 'audio_too_large'
|
||||
description = "Audio size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedAudioTypeError(BaseHTTPException):
|
||||
error_code = 'unsupported_audio_type'
|
||||
description = "Audio type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
error_code = 'provider_not_support_speech_to_text'
|
||||
description = "Provider not support speech to text."
|
||||
code = 400
|
||||
@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
raise NotFound("Message Not Exists.")
|
||||
except MoreLikeThisDisabledError:
|
||||
raise AppMoreLikeThisDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
|
||||
except MoreLikeThisDisabledError:
|
||||
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||
except ProviderTokenNotInitError as ex:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
|
||||
except QuotaExceededError:
|
||||
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||
except ModelCurrentlyNotSupportError:
|
||||
@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
|
||||
64
api/controllers/web/passport.py
Normal file
64
api/controllers/web/passport.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from controllers.web import api
|
||||
from flask_restful import Resource
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Unauthorized, NotFound
|
||||
from models.model import Site, EndUser, App
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
|
||||
class PassportResource(Resource):
|
||||
"""Base resource for passport."""
|
||||
def get(self):
|
||||
app_id = request.headers.get('X-App-Code')
|
||||
if app_id is None:
|
||||
raise Unauthorized('X-App-Code header is missing.')
|
||||
|
||||
# get site from db and check if it is normal
|
||||
site = db.session.query(Site).filter(
|
||||
Site.code == app_id,
|
||||
Site.status == 'normal'
|
||||
).first()
|
||||
if not site:
|
||||
raise NotFound()
|
||||
# get app from db and check if it is normal and enable_site
|
||||
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||
if not app_model or app_model.status != 'normal' or not app_model.enable_site:
|
||||
raise NotFound()
|
||||
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='browser',
|
||||
is_anonymous=True,
|
||||
session_id=generate_session_id(),
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
payload = {
|
||||
"iss": site.app_id,
|
||||
'sub': 'Web API Passport',
|
||||
'app_id': site.app_id,
|
||||
'end_user_id': end_user.id,
|
||||
}
|
||||
|
||||
tk = PassportService().issue(payload)
|
||||
|
||||
return {
|
||||
'access_token': tk,
|
||||
}
|
||||
|
||||
api.add_resource(PassportResource, '/passport')
|
||||
|
||||
def generate_session_id():
|
||||
"""
|
||||
Generate a unique session ID.
|
||||
"""
|
||||
while True:
|
||||
session_id = str(uuid.uuid4())
|
||||
existing_count = db.session.query(EndUser) \
|
||||
.filter(EndUser.session_id == session_id).count()
|
||||
if existing_count == 0:
|
||||
return session_id
|
||||
@@ -1,110 +1,50 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from functools import wraps
|
||||
|
||||
from flask import request, session
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Site, EndUser
|
||||
from models.model import App, EndUser
|
||||
from libs.passport import PassportService
|
||||
|
||||
|
||||
def validate_token(view=None):
|
||||
def validate_jwt_token(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
site = validate_and_get_site()
|
||||
|
||||
app_model = db.session.query(App).get(site.app_id)
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
|
||||
if app_model.status != 'normal':
|
||||
raise NotFound()
|
||||
|
||||
if not app_model.enable_site:
|
||||
raise NotFound()
|
||||
|
||||
end_user = create_or_update_end_user_for_session(app_model)
|
||||
app_model, end_user = decode_jwt_token()
|
||||
|
||||
return view(app_model, end_user, *args, **kwargs)
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
def validate_and_get_site():
|
||||
"""
|
||||
Validate and get API token.
|
||||
"""
|
||||
def decode_jwt_token():
|
||||
auth_header = request.headers.get('Authorization')
|
||||
if auth_header is None:
|
||||
raise Unauthorized('Authorization header is missing.')
|
||||
|
||||
if ' ' not in auth_header:
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
|
||||
auth_scheme, tk = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
|
||||
if auth_scheme != 'bearer':
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
|
||||
site = db.session.query(Site).filter(
|
||||
Site.code == auth_token,
|
||||
Site.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not site:
|
||||
decoded = PassportService().verify(tk)
|
||||
app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
if app_model.enable_site is False:
|
||||
raise Unauthorized('Site is disabled.')
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
|
||||
return site
|
||||
|
||||
|
||||
def create_or_update_end_user_for_session(app_model):
|
||||
"""
|
||||
Create or update session terminal based on session ID.
|
||||
"""
|
||||
if 'session_id' not in session:
|
||||
session['session_id'] = generate_session_id()
|
||||
|
||||
session_id = session.get('session_id')
|
||||
end_user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.session_id == session_id,
|
||||
EndUser.type == 'browser'
|
||||
).first()
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='browser',
|
||||
is_anonymous=True,
|
||||
session_id=session_id
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
|
||||
def generate_session_id():
|
||||
"""
|
||||
Generate a unique session ID.
|
||||
"""
|
||||
count = 1
|
||||
session_id = ''
|
||||
while count != 0:
|
||||
session_id = str(uuid.uuid4())
|
||||
count = db.session.query(EndUser) \
|
||||
.filter(EndUser.session_id == session_id).count()
|
||||
|
||||
return session_id
|
||||
|
||||
return app_model, end_user
|
||||
|
||||
class WebApiResource(Resource):
|
||||
method_decorators = [validate_token]
|
||||
method_decorators = [validate_jwt_token]
|
||||
|
||||
@@ -3,50 +3,34 @@ from typing import Optional
|
||||
|
||||
import langchain
|
||||
from flask import Flask
|
||||
from jieba.analyse import default_tfidf
|
||||
from langchain import set_handler
|
||||
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
|
||||
from llama_index import IndexStructType, QueryMode
|
||||
from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
|
||||
from core.index.keyword_table.stopwords import STOPWORDS
|
||||
from core.prompt.prompt_template import OneLineFormatter
|
||||
from core.vector_store.vector_store import VectorStore
|
||||
from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery
|
||||
|
||||
|
||||
class HostedOpenAICredential(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
class HostedAnthropicCredential(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
class HostedLLMCredentials(BaseModel):
|
||||
openai: Optional[HostedOpenAICredential] = None
|
||||
anthropic: Optional[HostedAnthropicCredential] = None
|
||||
|
||||
|
||||
hosted_llm_credentials = HostedLLMCredentials()
|
||||
|
||||
|
||||
def init_app(app: Flask):
|
||||
formatter = OneLineFormatter()
|
||||
DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
|
||||
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
|
||||
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
|
||||
QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
|
||||
QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
|
||||
}
|
||||
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
|
||||
QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
|
||||
QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
|
||||
}
|
||||
|
||||
default_tfidf.stop_words = STOPWORDS
|
||||
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
||||
langchain.verbose = True
|
||||
set_handler(DifyStdOutCallbackHandler())
|
||||
|
||||
if app.config.get("OPENAI_API_KEY"):
|
||||
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
|
||||
|
||||
if app.config.get("ANTHROPIC_API_KEY"):
|
||||
hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
|
||||
|
||||
35
api/core/agent/agent/calc_token_mixin.py
Normal file
35
api/core/agent/agent/calc_token_mixin.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import cast, List
|
||||
|
||||
from langchain import OpenAI
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.schema import BaseMessage
|
||||
|
||||
from core.constant import llm_constant
|
||||
|
||||
|
||||
class CalcTokenMixin:
|
||||
|
||||
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
|
||||
llm = cast(ChatOpenAI, llm)
|
||||
return llm.get_num_tokens_from_messages(messages)
|
||||
|
||||
def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""
|
||||
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
|
||||
|
||||
:param llm:
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
llm = cast(ChatOpenAI, llm)
|
||||
llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
|
||||
completion_max_tokens = llm.max_tokens
|
||||
used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
|
||||
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
|
||||
|
||||
return rest_tokens
|
||||
|
||||
|
||||
class ExceededLLMTokensLimitError(Exception):
|
||||
pass
|
||||
83
api/core/agent/agent/multi_dataset_router_agent.py
Normal file
83
api/core/agent/agent/multi_dataset_router_agent.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
"""
|
||||
An Multi Dataset Retrieve Agent driven by Router.
|
||||
"""
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
return True
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
if len(self.tools) == 0:
|
||||
return AgentFinish(return_values={"output": ''}, log='')
|
||||
elif len(self.tools) == 1:
|
||||
tool = next(iter(self.tools))
|
||||
tool = cast(DatasetRetrieverTool, tool)
|
||||
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
|
||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||
|
||||
if intermediate_steps:
|
||||
_, observation = intermediate_steps[-1]
|
||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||
|
||||
return super().plan(intermediate_steps, callbacks, **kwargs)
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message,
|
||||
**kwargs,
|
||||
)
|
||||
112
api/core/agent/agent/openai_function_call.py
Normal file
112
api/core/agent/agent/openai_function_call.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
|
||||
_format_intermediate_steps
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
|
||||
|
||||
|
||||
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=cls.get_system_message(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = self.llm.max_tokens
|
||||
self.llm.max_tokens = 15
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
)
|
||||
|
||||
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||
|
||||
self.llm.max_tokens = original_max_tokens
|
||||
|
||||
return True if function_call else False
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
|
||||
# summarize messages if rest_tokens < 0
|
||||
try:
|
||||
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
return agent_decision
|
||||
|
||||
@classmethod
|
||||
def get_system_message(cls):
|
||||
return SystemMessage(content="You are a helpful AI assistant.\n"
|
||||
"The current date or current time you know is wrong.\n"
|
||||
"Respond directly if appropriate.")
|
||||
|
||||
def return_stopped_response(
|
||||
self,
|
||||
early_stopping_method: str,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
**kwargs: Any,
|
||||
) -> AgentFinish:
|
||||
try:
|
||||
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
|
||||
except ValueError:
|
||||
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
|
||||
132
api/core/agent/agent/openai_function_call_summarize_mixin.py
Normal file
132
api/core/agent/agent/openai_function_call_summarize_mixin.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from typing import cast, List
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.chat_models.openai import _convert_message_to_dict
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
|
||||
|
||||
|
||||
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel
|
||||
|
||||
def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
|
||||
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
|
||||
if rest_tokens >= 0:
|
||||
return messages
|
||||
|
||||
system_message = None
|
||||
human_message = None
|
||||
should_summary_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, SystemMessage):
|
||||
system_message = message
|
||||
elif isinstance(message, HumanMessage):
|
||||
human_message = message
|
||||
else:
|
||||
should_summary_messages.append(message)
|
||||
|
||||
if len(should_summary_messages) > 2:
|
||||
ai_message = should_summary_messages[-2]
|
||||
function_message = should_summary_messages[-1]
|
||||
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
|
||||
self.moving_summary_index = len(should_summary_messages)
|
||||
else:
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
new_messages = [system_message, human_message]
|
||||
|
||||
if self.moving_summary_index == 0:
|
||||
should_summary_messages.insert(0, human_message)
|
||||
|
||||
summary_handler = SummarizerMixin(llm=self.summary_llm)
|
||||
self.moving_summary_buffer = summary_handler.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
|
||||
new_messages.append(AIMessage(content=self.moving_summary_buffer))
|
||||
new_messages.append(ai_message)
|
||||
new_messages.append(function_message)
|
||||
|
||||
return new_messages
|
||||
|
||||
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
llm = cast(ChatOpenAI, llm)
|
||||
model, encoding = llm._get_encoding_model()
|
||||
if model.startswith("gpt-3.5-turbo"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-4"):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
for m in messages:
|
||||
message = _convert_message_to_dict(m)
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if key == "function_call":
|
||||
for f_key, f_value in value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if kwargs.get('functions'):
|
||||
for function in kwargs.get('functions'):
|
||||
num_tokens += len(encoding.encode('name'))
|
||||
num_tokens += len(encoding.encode(function.get("name")))
|
||||
num_tokens += len(encoding.encode('description'))
|
||||
num_tokens += len(encoding.encode(function.get("description")))
|
||||
parameters = function.get("parameters")
|
||||
num_tokens += len(encoding.encode('parameters'))
|
||||
if 'title' in parameters:
|
||||
num_tokens += len(encoding.encode('title'))
|
||||
num_tokens += len(encoding.encode(parameters.get("title")))
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode(parameters.get("type")))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += len(encoding.encode('properties'))
|
||||
for key, value in parameters.get('properties').items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
if field_key == 'enum':
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(enum_field))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
num_tokens += len(encoding.encode(str(field_value)))
|
||||
if 'required' in parameters:
|
||||
num_tokens += len(encoding.encode('required'))
|
||||
for required_field in parameters['required']:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(required_field))
|
||||
|
||||
return num_tokens
|
||||
102
api/core/agent/agent/openai_multi_function_call.py
Normal file
102
api/core/agent/agent/openai_multi_function_call.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
|
||||
from langchain.agents import BaseMultiActionAgent
|
||||
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
|
||||
_parse_ai_message
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
|
||||
|
||||
|
||||
class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseMultiActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=cls.get_system_message(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = self.llm.max_tokens
|
||||
self.llm.max_tokens = 15
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
)
|
||||
|
||||
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||
|
||||
self.llm.max_tokens = original_max_tokens
|
||||
|
||||
return True if function_call else False
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
|
||||
# summarize messages if rest_tokens < 0
|
||||
try:
|
||||
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
return agent_decision
|
||||
|
||||
@classmethod
|
||||
def get_system_message(cls):
|
||||
# get current time
|
||||
return SystemMessage(content="You are a helpful AI assistant.\n"
|
||||
"The current date or current time you know is wrong.\n"
|
||||
"Respond directly if appropriate.")
|
||||
29
api/core/agent/agent/output_parser/structured_chat.py
Normal file
29
api/core/agent/agent/output_parser/structured_chat.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser, \
|
||||
logger
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
|
||||
class StructuredChatOutputParser(LCStructuredChatOutputParser):
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
try:
|
||||
action_match = re.search(r"```(.*?)\n(.*?)```?", text, re.DOTALL)
|
||||
if action_match is not None:
|
||||
response = json.loads(action_match.group(2).strip(), strict=False)
|
||||
if isinstance(response, list):
|
||||
# gpt turbo frequently ignores the directive to emit a single action
|
||||
logger.warning("Got multiple action responses: %s", response)
|
||||
response = response[0]
|
||||
if response["action"] == "Final Answer":
|
||||
return AgentFinish({"output": response["action_input"]}, text)
|
||||
else:
|
||||
return AgentAction(
|
||||
response["action"], response.get("action_input", {}), text
|
||||
)
|
||||
else:
|
||||
return AgentFinish({"output": text}, text)
|
||||
except Exception as e:
|
||||
raise OutputParserException(f"Could not parse LLM output: {text}") from e
|
||||
187
api/core/agent/agent/structured_chat.py
Normal file
187
api/core/agent/agent/structured_chat.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import re
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
|
||||
from langchain import BasePromptTemplate
|
||||
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
|
||||
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||
Valid "action" values: "Final Answer" or {tool_names}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{{{{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $INPUT
|
||||
}}}}
|
||||
```
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
Action:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{{{{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}}}}
|
||||
```"""
|
||||
|
||||
|
||||
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
Using the ReACT mode to determine whether an agent is needed is costly,
|
||||
so it's better to just use an Agent for reasoning, which is cheaper.
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
return True
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
|
||||
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
|
||||
messages = []
|
||||
if prompts:
|
||||
messages = prompts[0].to_messages()
|
||||
|
||||
rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
|
||||
if rest_tokens < 0:
|
||||
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
|
||||
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
|
||||
try:
|
||||
return self.output_parser.parse(full_output)
|
||||
except OutputParserException:
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||
if len(intermediate_steps) >= 2:
|
||||
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||
should_summary_messages = [AIMessage(content=observation)
|
||||
for _, observation in should_summary_intermediate_steps]
|
||||
if self.moving_summary_index == 0:
|
||||
should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
|
||||
|
||||
self.moving_summary_index = len(intermediate_steps)
|
||||
else:
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
summary_handler = SummarizerMixin(llm=self.summary_llm)
|
||||
if self.moving_summary_buffer and 'chat_history' in kwargs:
|
||||
kwargs["chat_history"].pop()
|
||||
|
||||
self.moving_summary_buffer = summary_handler.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
|
||||
if 'chat_history' in kwargs:
|
||||
kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
|
||||
|
||||
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
) -> BasePromptTemplate:
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
|
||||
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
_memory_prompts = memory_prompts or []
|
||||
messages = [
|
||||
SystemMessagePromptTemplate.from_template(template),
|
||||
*_memory_prompts,
|
||||
HumanMessagePromptTemplate.from_template(human_message_template),
|
||||
]
|
||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
output_parser=output_parser,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -1,89 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
|
||||
from langchain.callbacks import CallbackManager
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
|
||||
|
||||
class AgentBuilder:
|
||||
@classmethod
|
||||
def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
|
||||
dataset_tool_callback_handler: DatasetToolCallbackHandler,
|
||||
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
|
||||
llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
|
||||
llm = LLMBuilder.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
model_name=agent_loop_gather_callback_handler.model_name,
|
||||
temperature=0,
|
||||
max_tokens=1024,
|
||||
callback_manager=llm_callback_manager
|
||||
)
|
||||
|
||||
tool_callback_manager = CallbackManager([
|
||||
agent_loop_gather_callback_handler,
|
||||
dataset_tool_callback_handler,
|
||||
DifyStdOutCallbackHandler()
|
||||
])
|
||||
|
||||
for tool in tools:
|
||||
tool.callback_manager = tool_callback_manager
|
||||
|
||||
prompt = cls.build_agent_prompt_template(
|
||||
tools=tools,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
agent_llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
agent = cls.build_agent(agent_llm_chain=agent_llm_chain, memory=memory)
|
||||
|
||||
agent_callback_manager = CallbackManager(
|
||||
[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
|
||||
)
|
||||
|
||||
agent_chain = AgentExecutor.from_agent_and_tools(
|
||||
tools=tools,
|
||||
agent=agent,
|
||||
memory=memory,
|
||||
callback_manager=agent_callback_manager,
|
||||
max_iterations=6,
|
||||
early_stopping_method="generate",
|
||||
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
|
||||
)
|
||||
|
||||
return agent_chain
|
||||
|
||||
@classmethod
|
||||
def build_agent_prompt_template(cls, tools, memory: Optional[BaseChatMemory]):
|
||||
if memory:
|
||||
prompt = ConversationalAgent.create_prompt(
|
||||
tools=tools,
|
||||
)
|
||||
else:
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def build_agent(cls, agent_llm_chain: LLMChain, memory: Optional[BaseChatMemory]):
|
||||
if memory:
|
||||
agent = ConversationalAgent(
|
||||
llm_chain=agent_llm_chain
|
||||
)
|
||||
else:
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=agent_llm_chain
|
||||
)
|
||||
|
||||
return agent
|
||||
122
api/core/agent/agent_executor.py
Normal file
122
api/core/agent/agent_executor.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import enum
|
||||
import logging
|
||||
from typing import Union, Optional
|
||||
|
||||
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
||||
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
|
||||
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
|
||||
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
|
||||
from langchain.agents import AgentExecutor as LCAgentExecutor
|
||||
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
class PlanningStrategy(str, enum.Enum):
|
||||
ROUTER = 'router'
|
||||
REACT = 'react'
|
||||
FUNCTION_CALL = 'function_call'
|
||||
MULTI_FUNCTION_CALL = 'multi_function_call'
|
||||
|
||||
|
||||
class AgentConfiguration(BaseModel):
|
||||
strategy: PlanningStrategy
|
||||
llm: BaseLanguageModel
|
||||
tools: list[BaseTool]
|
||||
summary_llm: BaseLanguageModel
|
||||
dataset_llm: BaseLanguageModel
|
||||
memory: Optional[BaseChatMemory] = None
|
||||
callbacks: Callbacks = None
|
||||
max_iterations: int = 6
|
||||
max_execution_time: Optional[float] = None
|
||||
early_stopping_method: str = "generate"
|
||||
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class AgentExecuteResult(BaseModel):
|
||||
strategy: PlanningStrategy
|
||||
output: Optional[str]
|
||||
configuration: AgentConfiguration
|
||||
|
||||
|
||||
class AgentExecutor:
|
||||
def __init__(self, configuration: AgentConfiguration):
|
||||
self.configuration = configuration
|
||||
self.agent = self._init_agent()
|
||||
|
||||
def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
|
||||
if self.configuration.strategy == PlanningStrategy.REACT:
|
||||
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
|
||||
llm=self.configuration.llm,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
summary_llm=self.configuration.summary_llm,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
|
||||
llm=self.configuration.llm,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_llm,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
|
||||
llm=self.configuration.llm,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_llm,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.ROUTER:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
|
||||
agent = MultiDatasetRouterAgent.from_llm_and_tools(
|
||||
llm=self.configuration.dataset_llm,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
|
||||
verbose=True
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
|
||||
|
||||
return agent
|
||||
|
||||
def should_use_agent(self, query: str) -> bool:
|
||||
return self.agent.should_use_agent(query)
|
||||
|
||||
def run(self, query: str) -> AgentExecuteResult:
|
||||
agent_executor = LCAgentExecutor.from_agent_and_tools(
|
||||
agent=self.agent,
|
||||
tools=self.configuration.tools,
|
||||
memory=self.configuration.memory,
|
||||
max_iterations=self.configuration.max_iterations,
|
||||
max_execution_time=self.configuration.max_execution_time,
|
||||
early_stopping_method=self.configuration.early_stopping_method,
|
||||
callbacks=self.configuration.callbacks
|
||||
)
|
||||
|
||||
try:
|
||||
output = agent_executor.run(query)
|
||||
except Exception:
|
||||
logging.exception("agent_executor run failed")
|
||||
output = None
|
||||
|
||||
return AgentExecuteResult(
|
||||
output=output,
|
||||
strategy=self.configuration.strategy,
|
||||
configuration=self.configuration
|
||||
)
|
||||
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
|
||||
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
@@ -12,6 +14,7 @@ from core.conversation_message_task import ConversationMessageTask
|
||||
|
||||
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
|
||||
"""Initialize callback handler."""
|
||||
@@ -19,6 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self.conversation_message_task = conversation_message_task
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
self.current_chain = None
|
||||
|
||||
@property
|
||||
@@ -28,6 +32,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
def clear_agent_loops(self) -> None:
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
@@ -60,13 +65,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
# kwargs={}
|
||||
if self._current_loop and self._current_loop.status == 'llm_started':
|
||||
self._current_loop.status = 'llm_end'
|
||||
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
self._current_loop.completion = response.generations[0][0].text
|
||||
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
if response.llm_output:
|
||||
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
completion_generation = response.generations[0][0]
|
||||
if isinstance(completion_generation, ChatGeneration):
|
||||
completion_message = completion_generation.message
|
||||
if 'function_call' in completion_message.additional_kwargs:
|
||||
self._current_loop.completion \
|
||||
= json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
|
||||
else:
|
||||
self._current_loop.completion = response.generations[0][0].text
|
||||
else:
|
||||
self._current_loop.completion = completion_generation.text
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
if response.llm_output:
|
||||
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
@@ -74,21 +87,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
logging.error(error)
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
pass
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
logging.error(error)
|
||||
self._message_agent_thought = None
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
@@ -107,15 +106,29 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
) -> Any:
|
||||
"""Run on agent action."""
|
||||
tool = action.tool
|
||||
tool_input = action.tool_input
|
||||
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
|
||||
thought = action.log[:action_name_position].strip() if action.log else ''
|
||||
tool_input = json.dumps({"query": action.tool_input}
|
||||
if isinstance(action.tool_input, str) else action.tool_input)
|
||||
completion = None
|
||||
if isinstance(action, openai_functions_agent.base._FunctionsAgentAction) \
|
||||
or isinstance(action, openai_functions_multi_agent.base._FunctionsAgentAction):
|
||||
thought = action.log.strip()
|
||||
completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']})
|
||||
else:
|
||||
action_name_position = action.log.index("Action:") if action.log else -1
|
||||
thought = action.log[:action_name_position].strip() if action.log else ''
|
||||
|
||||
if self._current_loop and self._current_loop.status == 'llm_end':
|
||||
self._current_loop.status = 'agent_action'
|
||||
self._current_loop.thought = thought
|
||||
self._current_loop.tool_name = tool
|
||||
self._current_loop.tool_input = tool_input
|
||||
if completion is not None:
|
||||
self._current_loop.completion = completion
|
||||
|
||||
self._message_agent_thought = self.conversation_message_task.on_agent_start(
|
||||
self.current_chain,
|
||||
self._current_loop
|
||||
)
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
@@ -138,10 +151,13 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._current_loop.completed_at = time.perf_counter()
|
||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||
|
||||
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_name, self._current_loop
|
||||
)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
@@ -150,16 +166,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
logging.error(error)
|
||||
self._agent_loops = []
|
||||
self._current_loop = None
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
color: Optional[str] = None,
|
||||
end: str = "",
|
||||
**kwargs: Optional[str],
|
||||
) -> None:
|
||||
"""Run on additional input from chains and agents."""
|
||||
pass
|
||||
self._message_agent_thought = None
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run on agent end."""
|
||||
@@ -169,10 +176,18 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._current_loop.completed = True
|
||||
self._current_loop.completed_at = time.perf_counter()
|
||||
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
|
||||
self._current_loop.thought = '[DONE]'
|
||||
self._message_agent_thought = self.conversation_message_task.on_agent_start(
|
||||
self.current_chain,
|
||||
self._current_loop
|
||||
)
|
||||
|
||||
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
|
||||
self.conversation_message_task.on_agent_end(
|
||||
self._message_agent_thought, self.model_name, self._current_loop
|
||||
)
|
||||
|
||||
self._agent_loops.append(self._current_loop)
|
||||
self._current_loop = None
|
||||
self._message_agent_thought = None
|
||||
elif not self._current_loop and self._agent_loops:
|
||||
self._agent_loops[-1].status = 'agent_finish'
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
from core.callback_handler.entity.dataset_query import DatasetQueryObj
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
@@ -11,6 +11,7 @@ from core.conversation_message_task import ConversationMessageTask
|
||||
|
||||
class DatasetToolCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
|
||||
"""Initialize callback handler."""
|
||||
@@ -43,9 +44,11 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
tool_name = serialized.get('name')
|
||||
dataset_id = tool_name[len("dataset-"):]
|
||||
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str))
|
||||
# tool_name = serialized.get('name')
|
||||
input_dict = json.loads(input_str.replace("'", "\""))
|
||||
dataset_id = input_dict.get('dataset_id')
|
||||
query = input_dict.get('query')
|
||||
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
@@ -66,52 +69,3 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
logging.error(error)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
logging.error(error)
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
color: Optional[str] = None,
|
||||
end: str = "",
|
||||
**kwargs: Optional[str],
|
||||
) -> None:
|
||||
"""Run on additional input from chains and agents."""
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run on agent end."""
|
||||
pass
|
||||
|
||||
@@ -10,9 +10,9 @@ class AgentLoop(BaseModel):
|
||||
tool_output: str = None
|
||||
|
||||
prompt: str = None
|
||||
prompt_tokens: int = None
|
||||
prompt_tokens: int = 0
|
||||
completion: str = None
|
||||
completion_tokens: int = None
|
||||
completion_tokens: int = 0
|
||||
|
||||
latency: float = None
|
||||
|
||||
|
||||
@@ -1,39 +1,26 @@
|
||||
from llama_index import Response
|
||||
from typing import List
|
||||
|
||||
from langchain.schema import Document
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
|
||||
class IndexToolCallbackHandler:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._response = None
|
||||
|
||||
@property
|
||||
def response(self) -> Response:
|
||||
return self._response
|
||||
|
||||
def on_tool_end(self, response: Response) -> None:
|
||||
"""Handle tool end."""
|
||||
self._response = response
|
||||
|
||||
|
||||
class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
|
||||
class DatasetIndexToolCallbackHandler:
|
||||
"""Callback handler for dataset tool."""
|
||||
|
||||
def __init__(self, dataset_id: str) -> None:
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
|
||||
def on_tool_end(self, response: Response) -> None:
|
||||
def on_tool_end(self, documents: List[Document]) -> None:
|
||||
"""Handle tool end."""
|
||||
for node in response.source_nodes:
|
||||
index_node_id = node.node.doc_id
|
||||
for document in documents:
|
||||
doc_id = document.metadata['doc_id']
|
||||
|
||||
# add hit count to document segment
|
||||
db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self.dataset_id,
|
||||
DocumentSegment.index_node_id == index_node_id
|
||||
DocumentSegment.index_node_id == doc_id
|
||||
).update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
|
||||
from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
|
||||
|
||||
from core.callback_handler.entity.llm_message import LLMMessage
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
|
||||
|
||||
class LLMCallbackHandler(BaseCallbackHandler):
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
|
||||
def __init__(self, llm: BaseLanguageModel,
|
||||
conversation_message_task: ConversationMessageTask):
|
||||
self.llm = llm
|
||||
self.llm_message = LLMMessage()
|
||||
@@ -25,41 +24,41 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
self.start_at = time.perf_counter()
|
||||
real_prompts = []
|
||||
for message in messages[0]:
|
||||
if message.type == 'human':
|
||||
role = 'user'
|
||||
elif message.type == 'ai':
|
||||
role = 'assistant'
|
||||
else:
|
||||
role = 'system'
|
||||
|
||||
real_prompts.append({
|
||||
"role": role,
|
||||
"text": message.content
|
||||
})
|
||||
|
||||
self.llm_message.prompt = real_prompts
|
||||
self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
self.start_at = time.perf_counter()
|
||||
|
||||
if 'Chat' in serialized['name']:
|
||||
real_prompts = []
|
||||
messages = []
|
||||
for prompt in prompts:
|
||||
role, content = prompt.split(': ', maxsplit=1)
|
||||
if role == 'human':
|
||||
role = 'user'
|
||||
message = HumanMessage(content=content)
|
||||
elif role == 'ai':
|
||||
role = 'assistant'
|
||||
message = AIMessage(content=content)
|
||||
else:
|
||||
message = SystemMessage(content=content)
|
||||
self.llm_message.prompt = [{
|
||||
"role": 'user',
|
||||
"text": prompts[0]
|
||||
}]
|
||||
|
||||
real_prompt = {
|
||||
"role": role,
|
||||
"text": content
|
||||
}
|
||||
real_prompts.append(real_prompt)
|
||||
messages.append(message)
|
||||
|
||||
self.llm_message.prompt = real_prompts
|
||||
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
|
||||
else:
|
||||
self.llm_message.prompt = [{
|
||||
"role": 'user',
|
||||
"text": prompts[0]
|
||||
}]
|
||||
|
||||
self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
|
||||
self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
end_at = time.perf_counter()
|
||||
@@ -68,14 +67,18 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
if not self.conversation_message_task.streaming:
|
||||
self.conversation_message_task.append_message_text(response.generations[0][0].text)
|
||||
self.llm_message.completion = response.generations[0][0].text
|
||||
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
|
||||
|
||||
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
|
||||
|
||||
self.conversation_message_task.save_message(self.llm_message)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
self.conversation_message_task.append_message_text(token)
|
||||
try:
|
||||
self.conversation_message_task.append_message_text(token)
|
||||
except ConversationTaskStoppedException as ex:
|
||||
self.on_llm_error(error=ex)
|
||||
raise ex
|
||||
|
||||
self.llm_message.completion += token
|
||||
|
||||
def on_llm_error(
|
||||
@@ -90,58 +93,3 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
|
||||
else:
|
||||
logging.error(error)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
color: Optional[str] = None,
|
||||
end: str = "",
|
||||
**kwargs: Optional[str],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_agent_finish(
|
||||
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.callback_handler.entity.chain_result import ChainResult
|
||||
@@ -14,21 +13,20 @@ from core.conversation_message_task import ConversationMessageTask
|
||||
|
||||
class MainChainGatherCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
raise_error: bool = True
|
||||
|
||||
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
|
||||
"""Initialize callback handler."""
|
||||
self._current_chain_result = None
|
||||
self._current_chain_message = None
|
||||
self.conversation_message_task = conversation_message_task
|
||||
self.agent_loop_gather_callback_handler = AgentLoopGatherCallbackHandler(
|
||||
llm_constant.agent_model_name,
|
||||
conversation_message_task
|
||||
)
|
||||
self.agent_callback = None
|
||||
|
||||
def clear_chain_results(self) -> None:
|
||||
self._current_chain_result = None
|
||||
self._current_chain_message = None
|
||||
self.agent_loop_gather_callback_handler.current_chain = None
|
||||
if self.agent_callback:
|
||||
self.agent_callback.current_chain = None
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
@@ -50,13 +48,16 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
if not self._current_chain_result:
|
||||
self._current_chain_result = ChainResult(
|
||||
type=serialized['name'],
|
||||
prompt=inputs,
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
|
||||
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
|
||||
chain_type = serialized['id'][-1]
|
||||
if chain_type:
|
||||
self._current_chain_result = ChainResult(
|
||||
type=chain_type,
|
||||
prompt=inputs,
|
||||
started_at=time.perf_counter()
|
||||
)
|
||||
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
|
||||
if self.agent_callback:
|
||||
self.agent_callback.current_chain = self._current_chain_message
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
@@ -74,64 +75,4 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
logging.error(error)
|
||||
self.clear_chain_results()
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
logging.error(error)
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
logging.error(error)
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
color: Optional[str] = None,
|
||||
end: str = "",
|
||||
**kwargs: Optional[str],
|
||||
) -> None:
|
||||
"""Run on additional input from chains and agents."""
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run on agent end."""
|
||||
pass
|
||||
self.clear_chain_results()
|
||||
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage
|
||||
|
||||
|
||||
class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
||||
@@ -13,17 +14,23 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""Initialize callback handler."""
|
||||
self.color = color
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
print_text("\n[on_chat_model_start]\n", color='blue')
|
||||
for sub_messages in messages:
|
||||
for sub_message in sub_messages:
|
||||
print_text(str(sub_message) + "\n", color='blue')
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out the prompts."""
|
||||
print_text("\n[on_llm_start]\n", color='blue')
|
||||
|
||||
if 'Chat' in serialized['name']:
|
||||
for prompt in prompts:
|
||||
print_text(prompt + "\n", color='blue')
|
||||
else:
|
||||
print_text(prompts[0] + "\n", color='blue')
|
||||
print_text(prompts[0] + "\n", color='blue')
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
@@ -44,8 +51,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
class_name = serialized["name"]
|
||||
print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
|
||||
chain_type = serialized['id'][-1]
|
||||
print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
@@ -117,6 +124,26 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""Run on agent end."""
|
||||
print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
|
||||
|
||||
@property
|
||||
def ignore_chat_model(self) -> bool:
|
||||
"""Whether to ignore chat model callbacks."""
|
||||
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
|
||||
|
||||
|
||||
class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
|
||||
"""Callback handler for streaming. Only works with LLMs that support streaming."""
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.callbacks import CallbackManager
|
||||
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
|
||||
from core.chain.tool_chain import ToolChain
|
||||
|
||||
|
||||
class ChainBuilder:
|
||||
@classmethod
|
||||
def to_tool_chain(cls, tool, **kwargs) -> ToolChain:
|
||||
return ToolChain(
|
||||
tool=tool,
|
||||
input_key=kwargs.get('input_key', 'input'),
|
||||
output_key=kwargs.get('output_key', 'tool_output'),
|
||||
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def to_sensitive_word_avoidance_chain(cls, tool_config: dict, **kwargs) -> Optional[
|
||||
SensitiveWordAvoidanceChain]:
|
||||
sensitive_words = tool_config.get("words", "")
|
||||
if tool_config.get("enabled", False) \
|
||||
and sensitive_words:
|
||||
return SensitiveWordAvoidanceChain(
|
||||
sensitive_words=sensitive_words.split(","),
|
||||
canned_response=tool_config.get("canned_response", ''),
|
||||
output_key="sensitive_word_avoidance_output",
|
||||
callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -1,135 +0,0 @@
|
||||
"""Base classes for LLM-powered router chains."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
|
||||
|
||||
|
||||
class Route(NamedTuple):
|
||||
destination: Optional[str]
|
||||
next_inputs: Dict[str, Any]
|
||||
|
||||
|
||||
class LLMRouterChain(Chain):
|
||||
"""A router chain that uses an LLM chain to perform routing."""
|
||||
|
||||
llm_chain: LLMChain
|
||||
"""LLM chain used to perform routing"""
|
||||
|
||||
@root_validator()
|
||||
def validate_prompt(cls, values: dict) -> dict:
|
||||
prompt = values["llm_chain"].prompt
|
||||
if prompt.output_parser is None:
|
||||
raise ValueError(
|
||||
"LLMRouterChain requires base llm_chain prompt to have an output"
|
||||
" parser that converts LLM text output to a dictionary with keys"
|
||||
" 'destination' and 'next_inputs'. Received a prompt with no output"
|
||||
" parser."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the LLM chain prompt expects.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.llm_chain.input_keys
|
||||
|
||||
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
|
||||
super()._validate_outputs(outputs)
|
||||
if not isinstance(outputs["next_inputs"], dict):
|
||||
raise ValueError
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
output = cast(
|
||||
Dict[str, Any],
|
||||
self.llm_chain.predict_and_parse(**inputs),
|
||||
)
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
|
||||
) -> LLMRouterChain:
|
||||
"""Convenience constructor."""
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
return ["destination", "next_inputs"]
|
||||
|
||||
def route(self, inputs: Dict[str, Any]) -> Route:
|
||||
result = self(inputs)
|
||||
return Route(result["destination"], result["next_inputs"])
|
||||
|
||||
|
||||
class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
|
||||
"""Parser for output of router chain int he multi-prompt chain."""
|
||||
|
||||
default_destination: str = "DEFAULT"
|
||||
next_inputs_type: Type = str
|
||||
next_inputs_inner_key: str = "input"
|
||||
|
||||
def parse_json_markdown(self, json_string: str) -> dict:
|
||||
# Remove the triple backticks if present
|
||||
start_index = json_string.find("```json")
|
||||
end_index = json_string.find("```", start_index + len("```json"))
|
||||
|
||||
if start_index != -1 and end_index != -1:
|
||||
extracted_content = json_string[start_index + len("```json"):end_index].strip()
|
||||
|
||||
# Parse the JSON string into a Python dictionary
|
||||
parsed = json.loads(extracted_content)
|
||||
else:
|
||||
raise Exception("Could not find JSON block in the output.")
|
||||
|
||||
return parsed
|
||||
|
||||
def parse_and_check_json_markdown(self, text: str, expected_keys: List[str]) -> dict:
|
||||
try:
|
||||
json_obj = self.parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
||||
for key in expected_keys:
|
||||
if key not in json_obj:
|
||||
raise OutputParserException(
|
||||
f"Got invalid return object. Expected key `{key}` "
|
||||
f"to be present, but got {json_obj}"
|
||||
)
|
||||
return json_obj
|
||||
|
||||
def parse(self, text: str) -> Dict[str, Any]:
|
||||
try:
|
||||
expected_keys = ["destination", "next_inputs"]
|
||||
parsed = self.parse_and_check_json_markdown(text, expected_keys)
|
||||
if not isinstance(parsed["destination"], str):
|
||||
raise ValueError("Expected 'destination' to be a string.")
|
||||
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
|
||||
raise ValueError(
|
||||
f"Expected 'next_inputs' to be {self.next_inputs_type}."
|
||||
)
|
||||
parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
|
||||
if (
|
||||
parsed["destination"].strip().lower()
|
||||
== self.default_destination.lower()
|
||||
):
|
||||
parsed["destination"] = None
|
||||
else:
|
||||
parsed["destination"] = parsed["destination"].strip()
|
||||
return parsed
|
||||
except Exception as e:
|
||||
raise OutputParserException(
|
||||
f"Parsing text\n{text}\n raised following error:\n{e}"
|
||||
)
|
||||
@@ -1,108 +0,0 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from langchain.callbacks import SharedCallbackManager, CallbackManager
|
||||
from langchain.chains import SequentialChain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.chain.chain_builder import ChainBuilder
|
||||
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class MainChainBuilder:
|
||||
@classmethod
|
||||
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
|
||||
conversation_message_task: ConversationMessageTask):
|
||||
first_input_key = "input"
|
||||
final_output_key = "output"
|
||||
|
||||
chains = []
|
||||
|
||||
chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)
|
||||
|
||||
# agent mode
|
||||
tool_chains, chains_output_key = cls.get_agent_chains(
|
||||
tenant_id=tenant_id,
|
||||
agent_mode=agent_mode,
|
||||
memory=memory,
|
||||
conversation_message_task=conversation_message_task
|
||||
)
|
||||
chains += tool_chains
|
||||
|
||||
if chains_output_key:
|
||||
final_output_key = chains_output_key
|
||||
|
||||
if len(chains) == 0:
|
||||
return None
|
||||
|
||||
for chain in chains:
|
||||
# do not add handler into singleton callback manager
|
||||
if not isinstance(chain.callback_manager, SharedCallbackManager):
|
||||
chain.callback_manager.add_handler(chain_callback_handler)
|
||||
|
||||
# build main chain
|
||||
overall_chain = SequentialChain(
|
||||
chains=chains,
|
||||
input_variables=[first_input_key],
|
||||
output_variables=[final_output_key],
|
||||
memory=memory, # only for use the memory prompt input key
|
||||
)
|
||||
|
||||
return overall_chain
|
||||
|
||||
@classmethod
|
||||
def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
|
||||
conversation_message_task: ConversationMessageTask):
|
||||
# agent mode
|
||||
chains = []
|
||||
if agent_mode and agent_mode.get('enabled'):
|
||||
tools = agent_mode.get('tools', [])
|
||||
|
||||
pre_fixed_chains = []
|
||||
# agent_tools = []
|
||||
datasets = []
|
||||
for tool in tools:
|
||||
tool_type = list(tool.keys())[0]
|
||||
tool_config = list(tool.values())[0]
|
||||
if tool_type == 'sensitive-word-avoidance':
|
||||
chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
|
||||
if chain:
|
||||
pre_fixed_chains.append(chain)
|
||||
elif tool_type == "dataset":
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == tool_config.get("id")
|
||||
).first()
|
||||
|
||||
if dataset:
|
||||
datasets.append(dataset)
|
||||
|
||||
# add pre-fixed chains
|
||||
chains += pre_fixed_chains
|
||||
|
||||
if len(datasets) > 0:
|
||||
# tool to chain
|
||||
multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
|
||||
tenant_id=tenant_id,
|
||||
datasets=datasets,
|
||||
conversation_message_task=conversation_message_task,
|
||||
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
|
||||
)
|
||||
chains.append(multi_dataset_router_chain)
|
||||
|
||||
final_output_key = cls.get_chains_output_key(chains)
|
||||
|
||||
return chains, final_output_key
|
||||
|
||||
@classmethod
|
||||
def get_chains_output_key(cls, chains: List[Chain]):
|
||||
if len(chains) > 0:
|
||||
return chains[-1].output_keys[0]
|
||||
return None
|
||||
@@ -1,140 +0,0 @@
|
||||
from typing import Mapping, List, Dict, Any, Optional
|
||||
|
||||
from langchain import LLMChain, PromptTemplate, ConversationChain
|
||||
from langchain.callbacks import CallbackManager
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from pydantic import Extra
|
||||
|
||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.tool.dataset_tool_builder import DatasetToolBuilder
|
||||
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
|
||||
from models.dataset import Dataset
|
||||
|
||||
MULTI_PROMPT_ROUTER_TEMPLATE = """
|
||||
Given a raw text input to a language model select the model prompt best suited for \
|
||||
the input. You will be given the names of the available prompts and a description of \
|
||||
what the prompt is best suited for. You may also revise the original input if you \
|
||||
think that revising it will ultimately lead to a better response from the language \
|
||||
model.
|
||||
|
||||
<< FORMATTING >>
|
||||
Return a markdown code snippet with a JSON object formatted to look like:
|
||||
```json
|
||||
{{{{
|
||||
"destination": string \\ name of the prompt to use or "DEFAULT"
|
||||
"next_inputs": string \\ a potentially modified version of the original input
|
||||
}}}}
|
||||
```
|
||||
|
||||
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
|
||||
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
|
||||
REMEMBER: "next_inputs" can just be the original input if you don't think any \
|
||||
modifications are needed.
|
||||
|
||||
<< CANDIDATE PROMPTS >>
|
||||
{destinations}
|
||||
|
||||
<< INPUT >>
|
||||
{{input}}
|
||||
|
||||
<< OUTPUT >>
|
||||
"""
|
||||
|
||||
|
||||
class MultiDatasetRouterChain(Chain):
|
||||
"""Use a single chain to route an input to one of multiple candidate chains."""
|
||||
|
||||
router_chain: LLMRouterChain
|
||||
"""Chain for deciding a destination chain and the input to it."""
|
||||
dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
|
||||
"""Map of name to candidate chains that inputs can be routed to."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the router chain prompt expects.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.router_chain.input_keys
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
return ["text"]
|
||||
|
||||
@classmethod
|
||||
def from_datasets(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
datasets: List[Dataset],
|
||||
conversation_message_task: ConversationMessageTask,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Convenience constructor for instantiating from destination prompts."""
|
||||
llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
|
||||
llm = LLMBuilder.to_llm(
|
||||
tenant_id=tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
temperature=0,
|
||||
max_tokens=1024,
|
||||
callback_manager=llm_callback_manager
|
||||
)
|
||||
|
||||
destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
|
||||
else ('useful for when you want to answer queries about the ' + d.name))
|
||||
for d in datasets]
|
||||
destinations_str = "\n".join(destinations)
|
||||
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
|
||||
destinations=destinations_str
|
||||
)
|
||||
router_prompt = PromptTemplate(
|
||||
template=router_template,
|
||||
input_variables=["input"],
|
||||
output_parser=RouterOutputParser(),
|
||||
)
|
||||
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
|
||||
dataset_tools = {}
|
||||
for dataset in datasets:
|
||||
dataset_tool = DatasetToolBuilder.build_dataset_tool(
|
||||
dataset=dataset,
|
||||
response_mode='no_synthesizer', # "compact"
|
||||
callback_handler=DatasetToolCallbackHandler(conversation_message_task)
|
||||
)
|
||||
dataset_tools[dataset.id] = dataset_tool
|
||||
return cls(
|
||||
router_chain=router_chain,
|
||||
dataset_tools=dataset_tools,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
if len(self.dataset_tools) == 0:
|
||||
return {"text": ''}
|
||||
elif len(self.dataset_tools) == 1:
|
||||
return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}
|
||||
|
||||
route = self.router_chain.route(inputs)
|
||||
|
||||
if not route.destination:
|
||||
return {"text": ''}
|
||||
elif route.destination in self.dataset_tools:
|
||||
return {"text": self.dataset_tools[route.destination].run(
|
||||
route.next_inputs['input']
|
||||
)}
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Received invalid destination chain name '{route.destination}'"
|
||||
)
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
|
||||
@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
|
||||
return self.canned_response
|
||||
return text
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
text = inputs[self.input_key]
|
||||
output = self._check_sensitive_word(text)
|
||||
return {self.output_key: output}
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class ToolChain(Chain):
|
||||
input_key: str = "input" #: :meta private:
|
||||
output_key: str = "output" #: :meta private:
|
||||
|
||||
tool: BaseTool
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "tool_chain"
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
input = inputs[self.input_key]
|
||||
output = self.tool.run(input, self.verbose)
|
||||
return {self.output_key: output}
|
||||
|
||||
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
input = inputs[self.input_key]
|
||||
output = await self.tool.arun(input, self.verbose)
|
||||
return {self.output_key: output}
|
||||
@@ -1,40 +1,43 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Union, Tuple
|
||||
|
||||
from langchain.callbacks import CallbackManager
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms import BaseLLM
|
||||
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
|
||||
from langchain.schema import BaseMessage, HumanMessage
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
|
||||
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.constant import llm_constant
|
||||
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
|
||||
DifyStdOutCallbackHandler
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
|
||||
from core.llm.error import LLMBadRequestError
|
||||
from core.llm.fake import FakeLLM
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.chain.main_chain_builder import MainChainBuilder
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBStringBufferSharedMemory
|
||||
from core.orchestrator_rule_parser import OrchestratorRuleParser
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import OutLinePromptTemplate
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
|
||||
from models.model import App, AppModelConfig, Account, Conversation, Message
|
||||
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
|
||||
|
||||
|
||||
class Completion:
|
||||
@classmethod
|
||||
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
user: Account, conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
|
||||
user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
|
||||
"""
|
||||
errors: ProviderTokenNotInitError
|
||||
"""
|
||||
cls.validate_query_tokens(app.tenant_id, app_model_config, query)
|
||||
query = PromptBuilder.process_template(query)
|
||||
|
||||
memory = None
|
||||
if conversation:
|
||||
@@ -48,6 +51,14 @@ class Completion:
|
||||
|
||||
inputs = conversation.inputs
|
||||
|
||||
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
|
||||
mode=app.mode,
|
||||
tenant_id=app.tenant_id,
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs
|
||||
)
|
||||
|
||||
conversation_message_task = ConversationMessageTask(
|
||||
task_id=task_id,
|
||||
app=app,
|
||||
@@ -60,17 +71,33 @@ class Completion:
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
# build main chain include agent
|
||||
main_chain = MainChainBuilder.to_langchain_components(
|
||||
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
|
||||
|
||||
# init orchestrator rule parser
|
||||
orchestrator_rule_parser = OrchestratorRuleParser(
|
||||
tenant_id=app.tenant_id,
|
||||
agent_mode=app_model_config.agent_mode_dict,
|
||||
memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
|
||||
conversation_message_task=conversation_message_task
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
|
||||
chain_output = ''
|
||||
if main_chain:
|
||||
chain_output = main_chain.run(query)
|
||||
# parse sensitive_word_avoidance_chain
|
||||
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
|
||||
if sensitive_word_avoidance_chain:
|
||||
query = sensitive_word_avoidance_chain.run(query)
|
||||
|
||||
# get agent executor
|
||||
agent_executor = orchestrator_rule_parser.to_agent_executor(
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
rest_tokens=rest_tokens_for_context_and_memory,
|
||||
chain_callback=chain_callback
|
||||
)
|
||||
|
||||
# run agent executor
|
||||
agent_execute_result = None
|
||||
if agent_executor:
|
||||
should_use_agent = agent_executor.should_use_agent(query)
|
||||
if should_use_agent:
|
||||
agent_execute_result = agent_executor.run(query)
|
||||
|
||||
# run the final llm
|
||||
try:
|
||||
@@ -80,7 +107,7 @@ class Completion:
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
chain_output=chain_output,
|
||||
agent_execute_result=agent_execute_result,
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
streaming=streaming
|
||||
@@ -95,9 +122,20 @@ class Completion:
|
||||
|
||||
@classmethod
|
||||
def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
chain_output: str,
|
||||
agent_execute_result: Optional[AgentExecuteResult],
|
||||
conversation_message_task: ConversationMessageTask,
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
|
||||
# When no extra pre prompt is specified,
|
||||
# the output of the agent can be used directly as the main output content without calling LLM again
|
||||
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
|
||||
and agent_execute_result.strategy != PlanningStrategy.ROUTER:
|
||||
final_llm = FakeLLM(response=agent_execute_result.output,
|
||||
origin_llm=agent_execute_result.configuration.llm,
|
||||
streaming=streaming)
|
||||
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
|
||||
response = final_llm.generate([[HumanMessage(content=query)]])
|
||||
return response
|
||||
|
||||
final_llm = LLMBuilder.to_llm_from_model(
|
||||
tenant_id=tenant_id,
|
||||
model=app_model_config.model_dict,
|
||||
@@ -108,17 +146,19 @@ class Completion:
|
||||
prompt, stop_words = cls.get_main_llm_prompt(
|
||||
mode=mode,
|
||||
llm=final_llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
chain_output=chain_output,
|
||||
agent_execute_result=agent_execute_result,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)
|
||||
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
final_llm=final_llm,
|
||||
model=app_model_config.model_dict,
|
||||
prompt=prompt,
|
||||
mode=mode
|
||||
)
|
||||
@@ -128,42 +168,31 @@ class Completion:
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
|
||||
chain_output: Optional[str],
|
||||
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
|
||||
pre_prompt: str, query: str, inputs: dict,
|
||||
agent_execute_result: Optional[AgentExecuteResult],
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
|
||||
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
|
||||
# disable template string in query
|
||||
query_params = OutLinePromptTemplate.from_template(template=query).input_variables
|
||||
if query_params:
|
||||
for query_param in query_params:
|
||||
if query_param not in inputs:
|
||||
inputs[query_param] = '{' + query_param + '}'
|
||||
|
||||
pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
|
||||
if mode == 'completion':
|
||||
prompt_template = OutLinePromptTemplate.from_template(
|
||||
template=("""Use the following CONTEXT as your learned knowledge:
|
||||
[CONTEXT]
|
||||
{context}
|
||||
[END CONTEXT]
|
||||
prompt_template = JinjaPromptTemplate.from_template(
|
||||
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
|
||||
<context>
|
||||
{{context}}
|
||||
</context>
|
||||
|
||||
When answer to user:
|
||||
- If you don't know, just say that you don't know.
|
||||
- If you don't know when you are not sure, ask for clarification.
|
||||
Avoid mentioning that you obtained the information from the context.
|
||||
And answer according to the language of the user's question.
|
||||
""" if chain_output else "")
|
||||
""" if agent_execute_result else "")
|
||||
+ (pre_prompt + "\n" if pre_prompt else "")
|
||||
+ "{query}\n"
|
||||
+ "{{query}}\n"
|
||||
)
|
||||
|
||||
if chain_output:
|
||||
inputs['context'] = chain_output
|
||||
context_params = OutLinePromptTemplate.from_template(template=chain_output).input_variables
|
||||
if context_params:
|
||||
for context_param in context_params:
|
||||
if context_param not in inputs:
|
||||
inputs[context_param] = '{' + context_param + '}'
|
||||
if agent_execute_result:
|
||||
inputs['context'] = agent_execute_result.output
|
||||
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
|
||||
prompt_content = prompt_template.format(
|
||||
@@ -187,18 +216,19 @@ And answer according to the language of the user's question.
|
||||
|
||||
if pre_prompt:
|
||||
pre_prompt_inputs = {k: inputs[k] for k in
|
||||
OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
|
||||
JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
|
||||
if k in inputs}
|
||||
|
||||
if pre_prompt_inputs:
|
||||
human_inputs.update(pre_prompt_inputs)
|
||||
|
||||
if chain_output:
|
||||
human_inputs['context'] = chain_output
|
||||
human_message_prompt += """Use the following CONTEXT as your learned knowledge.
|
||||
[CONTEXT]
|
||||
{context}
|
||||
[END CONTEXT]
|
||||
if agent_execute_result:
|
||||
human_inputs['context'] = agent_execute_result.output
|
||||
human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
|
||||
<context>
|
||||
{{context}}
|
||||
</context>
|
||||
|
||||
When answer to user:
|
||||
- If you don't know, just say that you don't know.
|
||||
@@ -210,7 +240,7 @@ And answer according to the language of the user's question.
|
||||
if pre_prompt:
|
||||
human_message_prompt += pre_prompt
|
||||
|
||||
query_prompt = "\nHuman: {query}\nAI: "
|
||||
query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
|
||||
|
||||
if memory:
|
||||
# append chat histories
|
||||
@@ -219,20 +249,17 @@ And answer according to the language of the user's question.
|
||||
inputs=human_inputs
|
||||
)
|
||||
|
||||
curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
|
||||
rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
|
||||
- memory.llm.max_tokens - curr_message_tokens
|
||||
curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
|
||||
model_name = model['name']
|
||||
max_tokens = model.get("completion_params").get('max_tokens')
|
||||
rest_tokens = llm_constant.max_context_token_length[model_name] \
|
||||
- max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
|
||||
|
||||
# disable template string in query
|
||||
histories_params = OutLinePromptTemplate.from_template(template=histories).input_variables
|
||||
if histories_params:
|
||||
for histories_param in histories_params:
|
||||
if histories_param not in human_inputs:
|
||||
human_inputs[histories_param] = '{' + histories_param + '}'
|
||||
|
||||
human_message_prompt += "\n\n" + histories
|
||||
human_message_prompt += "\n\n" if human_message_prompt else ""
|
||||
human_message_prompt += "Here is the chat histories between human and assistant, " \
|
||||
"inside <histories></histories> XML tags.\n\n<histories>\n"
|
||||
human_message_prompt += histories + "\n</histories>"
|
||||
|
||||
human_message_prompt += query_prompt
|
||||
|
||||
@@ -244,24 +271,24 @@ And answer according to the language of the user's question.
|
||||
|
||||
messages.append(human_message)
|
||||
|
||||
return messages, ['\nHuman:']
|
||||
for message in messages:
|
||||
message.content = re.sub(r'<\|.*?\|>', '', message.content)
|
||||
|
||||
return messages, ['\nHuman:', '</histories>']
|
||||
|
||||
@classmethod
|
||||
def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
|
||||
streaming: bool,
|
||||
conversation_message_task: ConversationMessageTask) -> CallbackManager:
|
||||
def get_llm_callbacks(cls, llm: BaseLanguageModel,
|
||||
streaming: bool,
|
||||
conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
|
||||
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
|
||||
if streaming:
|
||||
callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
|
||||
return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
|
||||
else:
|
||||
callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]
|
||||
|
||||
return CallbackManager(callback_handlers)
|
||||
return [llm_callback_handler, DifyStdOutCallbackHandler()]
|
||||
|
||||
@classmethod
|
||||
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
|
||||
max_token_limit: int) -> \
|
||||
str:
|
||||
max_token_limit: int) -> str:
|
||||
"""Get memory messages."""
|
||||
memory.max_token_limit = max_token_limit
|
||||
memory_key = memory.memory_variables[0]
|
||||
@@ -293,29 +320,51 @@ And answer according to the language of the user's question.
|
||||
return memory
|
||||
|
||||
@classmethod
|
||||
def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
|
||||
def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
|
||||
query: str, inputs: dict) -> int:
|
||||
llm = LLMBuilder.to_llm_from_model(
|
||||
tenant_id=tenant_id,
|
||||
model=app_model_config.model_dict
|
||||
)
|
||||
|
||||
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
|
||||
max_tokens = llm.max_tokens
|
||||
model_name = app_model_config.model_dict.get("name")
|
||||
model_limited_tokens = llm_constant.max_context_token_length[model_name]
|
||||
max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
|
||||
|
||||
if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
|
||||
raise LLMBadRequestError("Query is too long")
|
||||
# get prompt without memory and context
|
||||
prompt, _ = cls.get_main_llm_prompt(
|
||||
mode=mode,
|
||||
llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
agent_execute_result=None,
|
||||
memory=None
|
||||
)
|
||||
|
||||
prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
|
||||
else llm.get_num_tokens_from_messages(prompt)
|
||||
|
||||
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
|
||||
if rest_tokens < 0:
|
||||
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
|
||||
"or shrink the max token, or switch to a llm with a larger token limit size.")
|
||||
|
||||
return rest_tokens
|
||||
|
||||
@classmethod
|
||||
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
|
||||
def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
|
||||
prompt: Union[str, List[BaseMessage]], mode: str):
|
||||
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
|
||||
max_tokens = final_llm.max_tokens
|
||||
model_name = model.get("name")
|
||||
model_limited_tokens = llm_constant.max_context_token_length[model_name]
|
||||
max_tokens = model.get("completion_params").get('max_tokens')
|
||||
|
||||
if mode == 'completion' and isinstance(final_llm, BaseLLM):
|
||||
prompt_tokens = final_llm.get_num_tokens(prompt)
|
||||
else:
|
||||
prompt_tokens = final_llm.get_messages_tokens(prompt)
|
||||
prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
|
||||
|
||||
if prompt_tokens + max_tokens > model_limited_tokens:
|
||||
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
|
||||
@@ -324,9 +373,10 @@ And answer according to the language of the user's question.
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
|
||||
app_model_config: AppModelConfig, user: Account, streaming: bool):
|
||||
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||
|
||||
llm = LLMBuilder.to_llm_from_model(
|
||||
tenant_id=app.tenant_id,
|
||||
model_name='gpt-3.5-turbo',
|
||||
model=app_model_config.model_dict,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
@@ -334,10 +384,12 @@ And answer according to the language of the user's question.
|
||||
original_prompt, _ = cls.get_main_llm_prompt(
|
||||
mode="completion",
|
||||
llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
pre_prompt=pre_prompt,
|
||||
query=message.query,
|
||||
inputs=message.inputs,
|
||||
chain_output=None,
|
||||
agent_execute_result=None,
|
||||
memory=None
|
||||
)
|
||||
|
||||
@@ -360,10 +412,11 @@ And answer according to the language of the user's question.
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)
|
||||
llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
final_llm=llm,
|
||||
model=app_model_config.model_dict,
|
||||
prompt=prompt,
|
||||
mode='completion'
|
||||
)
|
||||
|
||||
@@ -1,39 +1,52 @@
|
||||
from _decimal import Decimal
|
||||
|
||||
models = {
|
||||
'claude-instant-1': 'anthropic', # 100,000 tokens
|
||||
'claude-2': 'anthropic', # 100,000 tokens
|
||||
'gpt-4': 'openai', # 8,192 tokens
|
||||
'gpt-4-32k': 'openai', # 32,768 tokens
|
||||
'gpt-3.5-turbo': 'openai', # 4,096 tokens
|
||||
'gpt-3.5-turbo-16k': 'openai', # 16384 tokens
|
||||
'text-davinci-003': 'openai', # 4,097 tokens
|
||||
'text-davinci-002': 'openai', # 4,097 tokens
|
||||
'text-curie-001': 'openai', # 2,049 tokens
|
||||
'text-babbage-001': 'openai', # 2,049 tokens
|
||||
'text-ada-001': 'openai', # 2,049 tokens
|
||||
'text-embedding-ada-002': 'openai' # 8191 tokens, 1536 dimensions
|
||||
'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions
|
||||
'whisper-1': 'openai'
|
||||
}
|
||||
|
||||
max_context_token_length = {
|
||||
'claude-instant-1': 100000,
|
||||
'claude-2': 100000,
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
'gpt-3.5-turbo-16k': 16384,
|
||||
'text-davinci-003': 4097,
|
||||
'text-davinci-002': 4097,
|
||||
'text-curie-001': 2049,
|
||||
'text-babbage-001': 2049,
|
||||
'text-ada-001': 2049,
|
||||
'text-embedding-ada-002': 8191
|
||||
'text-embedding-ada-002': 8191,
|
||||
}
|
||||
|
||||
models_by_mode = {
|
||||
'chat': [
|
||||
'claude-instant-1', # 100,000 tokens
|
||||
'claude-2', # 100,000 tokens
|
||||
'gpt-4', # 8,192 tokens
|
||||
'gpt-4-32k', # 32,768 tokens
|
||||
'gpt-3.5-turbo', # 4,096 tokens
|
||||
'gpt-3.5-turbo-16k', # 16,384 tokens
|
||||
],
|
||||
'completion': [
|
||||
'claude-instant-1', # 100,000 tokens
|
||||
'claude-2', # 100,000 tokens
|
||||
'gpt-4', # 8,192 tokens
|
||||
'gpt-4-32k', # 32,768 tokens
|
||||
'gpt-3.5-turbo', # 4,096 tokens
|
||||
'gpt-3.5-turbo-16k', # 16,384 tokens
|
||||
'text-davinci-003', # 4,097 tokens
|
||||
'text-davinci-002' # 4,097 tokens
|
||||
'text-curie-001', # 2,049 tokens
|
||||
@@ -48,6 +61,14 @@ models_by_mode = {
|
||||
model_currency = 'USD'
|
||||
|
||||
model_prices = {
|
||||
'claude-instant-1': {
|
||||
'prompt': Decimal('0.00163'),
|
||||
'completion': Decimal('0.00551'),
|
||||
},
|
||||
'claude-2': {
|
||||
'prompt': Decimal('0.01102'),
|
||||
'completion': Decimal('0.03268'),
|
||||
},
|
||||
'gpt-4': {
|
||||
'prompt': Decimal('0.03'),
|
||||
'completion': Decimal('0.06'),
|
||||
@@ -57,9 +78,13 @@ model_prices = {
|
||||
'completion': Decimal('0.12')
|
||||
},
|
||||
'gpt-3.5-turbo': {
|
||||
'prompt': Decimal('0.002'),
|
||||
'prompt': Decimal('0.0015'),
|
||||
'completion': Decimal('0.002')
|
||||
},
|
||||
'gpt-3.5-turbo-16k': {
|
||||
'prompt': Decimal('0.003'),
|
||||
'completion': Decimal('0.004')
|
||||
},
|
||||
'text-davinci-003': {
|
||||
'prompt': Decimal('0.02'),
|
||||
'completion': Decimal('0.02')
|
||||
@@ -77,7 +102,7 @@ model_prices = {
|
||||
'completion': Decimal('0.0004')
|
||||
},
|
||||
'text-embedding-ada-002': {
|
||||
'usage': Decimal('0.0004'),
|
||||
'usage': Decimal('0.0001'),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from core.constant import llm_constant
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.llm.provider.llm_provider_service import LLMProviderService
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import OutLinePromptTemplate
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@@ -52,11 +52,11 @@ class ConversationMessageTask:
|
||||
message=self.message,
|
||||
conversation=self.conversation,
|
||||
chain_pub=False, # disabled currently
|
||||
agent_thought_pub=False # disabled currently
|
||||
agent_thought_pub=True
|
||||
)
|
||||
|
||||
def init(self):
|
||||
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id)
|
||||
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
|
||||
self.model_dict['provider'] = provider_name
|
||||
|
||||
override_model_configs = None
|
||||
@@ -69,6 +69,7 @@ class ConversationMessageTask:
|
||||
"suggested_questions": self.app_model_config.suggested_questions_list,
|
||||
"suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
|
||||
"more_like_this": self.app_model_config.more_like_this_dict,
|
||||
"sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
|
||||
"user_input_form": self.app_model_config.user_input_form_list,
|
||||
}
|
||||
|
||||
@@ -78,7 +79,7 @@ class ConversationMessageTask:
|
||||
if self.mode == 'chat':
|
||||
introduction = self.app_model_config.opening_statement
|
||||
if introduction:
|
||||
prompt_template = OutLinePromptTemplate.from_template(template=PromptBuilder.process_template(introduction))
|
||||
prompt_template = JinjaPromptTemplate.from_template(template=introduction)
|
||||
prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
|
||||
try:
|
||||
introduction = prompt_template.format(**prompt_inputs)
|
||||
@@ -86,11 +87,10 @@ class ConversationMessageTask:
|
||||
pass
|
||||
|
||||
if self.app_model_config.pre_prompt:
|
||||
pre_prompt = PromptBuilder.process_template(self.app_model_config.pre_prompt)
|
||||
system_message = PromptBuilder.to_system_message(pre_prompt, self.inputs)
|
||||
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
|
||||
system_instruction = system_message.content
|
||||
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
|
||||
system_instruction_tokens = llm.get_messages_tokens([system_message])
|
||||
system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
|
||||
|
||||
if not self.conversation:
|
||||
self.is_new_conversation = True
|
||||
@@ -157,7 +157,7 @@ class ConversationMessageTask:
|
||||
self.message.message = llm_message.prompt
|
||||
self.message.message_tokens = message_tokens
|
||||
self.message.message_unit_price = message_unit_price
|
||||
self.message.answer = llm_message.completion.strip() if llm_message.completion else ''
|
||||
self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
|
||||
self.message.answer_tokens = answer_tokens
|
||||
self.message.answer_unit_price = answer_unit_price
|
||||
self.message.provider_response_latency = llm_message.latency
|
||||
@@ -186,6 +186,7 @@ class ConversationMessageTask:
|
||||
if provider and provider.provider_type == ProviderType.SYSTEM.value:
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == self.app.tenant_id,
|
||||
Provider.provider_name == provider.provider_name,
|
||||
Provider.quota_limit > Provider.quota_used
|
||||
).update({'quota_used': Provider.quota_used + 1})
|
||||
|
||||
@@ -207,7 +208,28 @@ class ConversationMessageTask:
|
||||
|
||||
self._pub_handler.pub_chain(message_chain)
|
||||
|
||||
def on_agent_end(self, message_chain: MessageChain, agent_model_name: str,
|
||||
def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
|
||||
message_agent_thought = MessageAgentThought(
|
||||
message_id=self.message.id,
|
||||
message_chain_id=message_chain.id,
|
||||
position=agent_loop.position,
|
||||
thought=agent_loop.thought,
|
||||
tool=agent_loop.tool_name,
|
||||
tool_input=agent_loop.tool_input,
|
||||
message=agent_loop.prompt,
|
||||
answer=agent_loop.completion,
|
||||
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
|
||||
db.session.add(message_agent_thought)
|
||||
db.session.flush()
|
||||
|
||||
self._pub_handler.pub_agent_thought(message_agent_thought)
|
||||
|
||||
return message_agent_thought
|
||||
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
|
||||
agent_loop: AgentLoop):
|
||||
agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
|
||||
agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
|
||||
@@ -222,34 +244,18 @@ class ConversationMessageTask:
|
||||
agent_answer_unit_price
|
||||
)
|
||||
|
||||
message_agent_loop = MessageAgentThought(
|
||||
message_id=self.message.id,
|
||||
message_chain_id=message_chain.id,
|
||||
position=agent_loop.position,
|
||||
thought=agent_loop.thought,
|
||||
tool=agent_loop.tool_name,
|
||||
tool_input=agent_loop.tool_input,
|
||||
observation=agent_loop.tool_output,
|
||||
tool_process_data='', # currently not support
|
||||
message=agent_loop.prompt,
|
||||
message_token=loop_message_tokens,
|
||||
message_unit_price=agent_message_unit_price,
|
||||
answer=agent_loop.completion,
|
||||
answer_token=loop_answer_tokens,
|
||||
answer_unit_price=agent_answer_unit_price,
|
||||
latency=agent_loop.latency,
|
||||
tokens=agent_loop.prompt_tokens + agent_loop.completion_tokens,
|
||||
total_price=loop_total_price,
|
||||
currency=llm_constant.model_currency,
|
||||
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
|
||||
db.session.add(message_agent_loop)
|
||||
message_agent_thought.observation = agent_loop.tool_output
|
||||
message_agent_thought.tool_process_data = '' # currently not support
|
||||
message_agent_thought.message_token = loop_message_tokens
|
||||
message_agent_thought.message_unit_price = agent_message_unit_price
|
||||
message_agent_thought.answer_token = loop_answer_tokens
|
||||
message_agent_thought.answer_unit_price = agent_answer_unit_price
|
||||
message_agent_thought.latency = agent_loop.latency
|
||||
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
|
||||
message_agent_thought.total_price = loop_total_price
|
||||
message_agent_thought.currency = llm_constant.model_currency
|
||||
db.session.flush()
|
||||
|
||||
self._pub_handler.pub_agent_thought(message_agent_loop)
|
||||
|
||||
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_query_obj.dataset_id,
|
||||
@@ -293,12 +299,12 @@ class PubHandler:
|
||||
if not user:
|
||||
raise ValueError("user is required")
|
||||
|
||||
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
|
||||
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
|
||||
return "generate_result:{}-{}".format(user_str, task_id)
|
||||
|
||||
@classmethod
|
||||
def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
|
||||
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
|
||||
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
|
||||
return "generate_result_stopped:{}-{}".format(user_str, task_id)
|
||||
|
||||
def pub_text(self, text: str):
|
||||
@@ -306,10 +312,10 @@ class PubHandler:
|
||||
'event': 'message',
|
||||
'data': {
|
||||
'task_id': self._task_id,
|
||||
'message_id': self._message.id,
|
||||
'message_id': str(self._message.id),
|
||||
'text': text,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': self._conversation.id
|
||||
'conversation_id': str(self._conversation.id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -346,16 +352,14 @@ class PubHandler:
|
||||
content = {
|
||||
'event': 'agent_thought',
|
||||
'data': {
|
||||
'id': message_agent_thought.id,
|
||||
'task_id': self._task_id,
|
||||
'message_id': self._message.id,
|
||||
'chain_id': message_agent_thought.message_chain_id,
|
||||
'agent_thought_id': message_agent_thought.id,
|
||||
'position': message_agent_thought.position,
|
||||
'thought': message_agent_thought.thought,
|
||||
'tool': message_agent_thought.tool,
|
||||
'tool_input': message_agent_thought.tool_input,
|
||||
'observation': message_agent_thought.observation,
|
||||
'answer': message_agent_thought.answer,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': self._conversation.id
|
||||
}
|
||||
@@ -388,6 +392,15 @@ class PubHandler:
|
||||
def _is_stopped(self):
|
||||
return redis_client.get(self._stopped_cache_key) is not None
|
||||
|
||||
@classmethod
|
||||
def ping(cls, user: Union[Account | EndUser], task_id: str):
|
||||
content = {
|
||||
'event': 'ping'
|
||||
}
|
||||
|
||||
channel = cls.generate_channel_name(user, task_id)
|
||||
redis_client.publish(channel, json.dumps(content))
|
||||
|
||||
@classmethod
|
||||
def stop(cls, user: Union[Account | EndUser], task_id: str):
|
||||
stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
|
||||
|
||||
66
api/core/data_loader/file_extractor.py
Normal file
66
api/core/data_loader/file_extractor.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Optional
|
||||
|
||||
import requests
|
||||
from langchain.document_loaders import TextLoader, Docx2txtLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.data_loader.loader.csv import CSVLoader
|
||||
from core.data_loader.loader.excel import ExcelLoader
|
||||
from core.data_loader.loader.html import HTMLLoader
|
||||
from core.data_loader.loader.markdown import MarkdownLoader
|
||||
from core.data_loader.loader.pdf import PdfLoader
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
|
||||
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
|
||||
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
|
||||
|
||||
class FileExtractor:
|
||||
@classmethod
|
||||
def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
storage.download(upload_file.key, file_path)
|
||||
|
||||
return cls.load_from_file(file_path, return_text, upload_file)
|
||||
|
||||
@classmethod
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
|
||||
response = requests.get(url, headers={
|
||||
"User-Agent": USER_AGENT
|
||||
})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(url).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(response.content)
|
||||
|
||||
return cls.load_from_file(file_path, return_text)
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, file_path: str, return_text: bool = False,
|
||||
upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
|
||||
input_file = Path(file_path)
|
||||
delimiter = '\n'
|
||||
if input_file.suffix == '.xlsx':
|
||||
loader = ExcelLoader(file_path)
|
||||
elif input_file.suffix == '.pdf':
|
||||
loader = PdfLoader(file_path, upload_file=upload_file)
|
||||
elif input_file.suffix in ['.md', '.markdown']:
|
||||
loader = MarkdownLoader(file_path, autodetect_encoding=True)
|
||||
elif input_file.suffix in ['.htm', '.html']:
|
||||
loader = HTMLLoader(file_path)
|
||||
elif input_file.suffix == '.docx':
|
||||
loader = Docx2txtLoader(file_path)
|
||||
elif input_file.suffix == '.csv':
|
||||
loader = CSVLoader(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
# txt
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
|
||||
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
|
||||
67
api/core/data_loader/loader/csv.py
Normal file
67
api/core/data_loader/loader/csv.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import logging
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
from langchain.document_loaders import CSVLoader as LCCSVLoader
|
||||
from langchain.document_loaders.helpers import detect_file_encodings
|
||||
|
||||
from models.dataset import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CSVLoader(LCCSVLoader):
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
source_column: Optional[str] = None,
|
||||
csv_args: Optional[Dict] = None,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = True,
|
||||
):
|
||||
self.file_path = file_path
|
||||
self.source_column = source_column
|
||||
self.encoding = encoding
|
||||
self.csv_args = csv_args or {}
|
||||
self.autodetect_encoding = autodetect_encoding
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load data into document objects."""
|
||||
try:
|
||||
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
|
||||
docs = self._read_from_file(csvfile)
|
||||
except UnicodeDecodeError as e:
|
||||
if self.autodetect_encoding:
|
||||
detected_encodings = detect_file_encodings(self.file_path)
|
||||
for encoding in detected_encodings:
|
||||
logger.debug("Trying encoding: ", encoding.encoding)
|
||||
try:
|
||||
with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
|
||||
docs = self._read_from_file(csvfile)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Error loading {self.file_path}") from e
|
||||
|
||||
return docs
|
||||
|
||||
def _read_from_file(self, csvfile):
|
||||
docs = []
|
||||
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
|
||||
for i, row in enumerate(csv_reader):
|
||||
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
|
||||
try:
|
||||
source = (
|
||||
row[self.source_column]
|
||||
if self.source_column is not None
|
||||
else ''
|
||||
)
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Source column '{self.source_column}' not found in CSV file."
|
||||
)
|
||||
metadata = {"source": source, "row": i}
|
||||
doc = Document(page_content=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
45
api/core/data_loader/loader/excel.py
Normal file
45
api/core/data_loader/loader/excel.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
from openpyxl.reader.excel import load_workbook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExcelLoader(BaseLoader):
|
||||
"""Load xlxs files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
data = []
|
||||
keys = []
|
||||
wb = load_workbook(filename=self._file_path, read_only=True)
|
||||
# loop over all sheets
|
||||
for sheet in wb:
|
||||
for row in sheet.iter_rows(values_only=True):
|
||||
if all(v is None for v in row):
|
||||
continue
|
||||
if keys == []:
|
||||
keys = list(map(str, row))
|
||||
else:
|
||||
row_dict = dict(zip(keys, list(map(str, row))))
|
||||
row_dict = {k: v for k, v in row_dict.items() if v}
|
||||
item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
|
||||
document = Document(page_content=item, metadata={'source': self._file_path})
|
||||
data.append(document)
|
||||
|
||||
return data
|
||||
35
api/core/data_loader/loader/html.py
Normal file
35
api/core/data_loader/loader/html.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HTMLLoader(BaseLoader):
|
||||
"""Load html files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
return [Document(page_content=self._load_as_text())]
|
||||
|
||||
def _load_as_text(self) -> str:
|
||||
with open(self._file_path, "rb") as fp:
|
||||
soup = BeautifulSoup(fp, 'html.parser')
|
||||
text = soup.get_text()
|
||||
text = text.strip() if text else ''
|
||||
|
||||
return text
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user