Compare commits


82 Commits
0.6.0 ... 0.6.2

Author SHA1 Message Date
takatost
6fa0e4072d fix: yarn install extract package err when using GitHub Cache in amd6… (#3383) 2024-04-12 00:04:09 +08:00
takatost
e15d18aa1c version to 0.6.2-fix1 (#3380) 2024-04-11 23:38:29 +08:00
takatost
164ef26a60 fix: variable pool mapping variable mixed up (#3378) 2024-04-11 23:19:28 +08:00
takatost
0dada847ef version to 0.6.2 (#3375) 2024-04-11 22:10:45 +08:00
takatost
36b7dbb8d0 fix: cohere tool call does not support single tool (#3373) 2024-04-11 21:32:18 +08:00
Chenhe Gu
02e483c99b update workflow intro mp4 codec (#3372) 2024-04-11 21:24:22 +08:00
Chenhe Gu
afe30e15a0 Update README.md (#3371) 2024-04-11 21:06:20 +08:00
takatost
9a1ea9ac03 fix: image token calc of OpenAI Compatible API (#3368) 2024-04-11 20:29:48 +08:00
Yeuoly
693647a141 Fix/Bing Search url endpoint cannot be customized (#3366) 2024-04-11 19:56:08 +08:00
Yeuoly
cea107b165 Refactor/react agent (#3355) 2024-04-11 18:34:17 +08:00
Joel
509c640a80 fix: var name too long would break ui in var assigner and end nodes (#3361) 2024-04-11 18:19:33 +08:00
Lao
617e7cee81 Added a note on the front-end docker build: use taobao source to accelerate the installation of front-end dependency packages to achieve the purpose of quickly building containers (#3358)
Co-authored-by: lbm21 <313338264@qq.com>
Co-authored-by: akou <beiming1201@gmail.com>
2024-04-11 18:14:58 +08:00
Joel
d87d4b9b56 fix: remove middle editor may cause render placement error (#3356) 2024-04-11 17:51:14 +08:00
Jyong
c889717d24 Fix issue : don't delete DatasetProcessRule, DatasetQuery and AppDatasetJoin when delete dataset with no document (#3354) 2024-04-11 17:43:22 +08:00
Jyong
1f302990c6 add segment with keyword issue (#3351)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2024-04-11 16:57:02 +08:00
Jyong
37024afe9c fix issue: user’s keywords do not affect when add segment (#3349) 2024-04-11 16:34:52 +08:00
Yeuoly
18b855140d fix/moonshot-function-call (#3339) 2024-04-11 15:42:26 +08:00
crazywoola
7c520b52c1 feat: update aws bedrock (#3326)
Co-authored-by: chenhe <guchenhe@gmail.com>
2024-04-11 15:38:55 +08:00
Joel
b98e363a5c fix: leave progress page still call indexing-status api (#3345) 2024-04-11 15:38:38 +08:00
crazywoola
0a7ea9d206 Doc/update readme (#3344) 2024-04-11 15:15:07 +08:00
crazywoola
3d473b9763 feat: make input size bigger in start (#3340) 2024-04-11 15:06:55 +08:00
Eric Wang
e0df7505f6 feat(llm/models): add gemini-1.5-pro (#2925) 2024-04-11 10:58:13 +08:00
呆萌闷油瓶
43bb0b0b93 chore:bump pypdfium2 from 4.16.0 to 4.17.0 (#3310) 2024-04-11 09:13:03 +08:00
Jyong
6164604462 fix dataset retrival in dataset mode (#3334) 2024-04-11 02:11:21 +08:00
takatost
826c422ac4 feat: Add Cohere Command R / R+ model support (#3333) 2024-04-11 01:22:55 +08:00
Kenny
bf63a43bda feat: support gpt-4-turbo-2024-04-09 model (#3300) 2024-04-10 22:55:46 +08:00
Bowen Liang
55fc46c707 improvement: speed up dependency installation in docker image rebuilds by mounting cache layer (#3218) 2024-04-10 22:49:04 +08:00
呆萌闷油瓶
5102430a68 feat:add 'name' field return (#3152) 2024-04-10 22:34:43 +08:00
Lao
0f897bc1f9 feat: add missing workflow i18n keys (#3309)
Co-authored-by: lbm21 <313338264@qq.com>
2024-04-10 22:20:14 +08:00
Chenhe Gu
d948b0b49b add german translations (#3322) 2024-04-10 22:05:27 +08:00
Jyong
b6de97ad53 Remove langchain dataset retrival agent logic (#3311) 2024-04-10 20:37:22 +08:00
Chenhe Gu
8cefa6b82e Update README.md (#3281) 2024-04-10 20:10:21 +08:00
dependabot[bot]
81e1b3fc61 chore(deps): bump katex from 0.16.8 to 0.16.10 in /web (#3307)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-10 16:29:56 +08:00
Nite Knite
4c1cfd9278 chore: address security alerts on braces escape and KaTeX (#3301) 2024-04-10 16:16:24 +08:00
Yeuoly
14bb0b02ac Feat/Agent-Image-Processing (#3293)
Co-authored-by: Joel <iamjoel007@gmail.com>
2024-04-10 14:48:40 +08:00
zxhlyh
240c793e7a fix: variable-assigner node connect (#3288) 2024-04-10 13:49:21 +08:00
Joel
89a853212b fix: var assigner input node can not find caused error (#3274) 2024-04-10 11:16:54 +08:00
takatost
97d1e0bbbb feat: vision parameter support of OpenAI Compatible API (#3272) 2024-04-10 11:13:56 +08:00
takatost
cfb5ccc7d3 fix: image was sent to an unsupported LLM when sending second message (#3268) 2024-04-10 10:29:52 +08:00
Yeuoly
835e547195 feat: gpt-4-turbo (#3263) 2024-04-10 10:28:52 +08:00
zxhlyh
af9ccb7072 fix: agent chat multiple model debug (#3258) 2024-04-09 22:24:02 +08:00
takatost
74de7cf33c version to 0.6.1 (#3253) 2024-04-09 21:21:09 +08:00
crazywoola
f5e65b98a9 feat: remove unregistered-llm-in-debug (#3251) 2024-04-09 20:49:52 +08:00
Chenhe Gu
eb76d7a226 make sure validation flow works for all model providers in bedrock (#3250) 2024-04-09 20:42:18 +08:00
Yeuoly
e635f3dc1d chore: remove langchain in tools (#3247) 2024-04-09 19:28:22 +08:00
takatost
2a6b7d57cb fix: token is not logging of question classifier node (#3249) 2024-04-09 19:25:08 +08:00
Joel
8bb225bec6 fix: number type in app would render as select type in webapp (#3244) 2024-04-09 18:33:47 +08:00
zxhlyh
39d3fc4742 feat: prompt-editor support undo (#3242) 2024-04-09 18:18:53 +08:00
KVOJJJin
5c98260cec Fix: picture of workflow (#3241) 2024-04-09 17:49:45 +08:00
takatost
f599f41336 fix: empty conversation list of explore chatbot (#3235) 2024-04-09 17:04:48 +08:00
zxhlyh
1384a6d0fd fix: workflow run edge status (#3236) 2024-04-09 16:53:34 +08:00
Bowen Liang
28089c98c1 fix: skip Celery warning by setting broker_connection_retry_on_startup config (#3188) 2024-04-09 16:14:43 +08:00
minakokojima
10d6d50b6c update link (#3226) 2024-04-09 16:09:46 +08:00
Joel
752f6fb15a fix: file not uploaded caused api error (#3228) 2024-04-09 15:54:36 +08:00
Jyong
8fcf459285 fix milvus database name parameter missed (#3229) 2024-04-09 15:54:13 +08:00
Leo Q
9c01bcb3e5 feat: support setting database used in Milvus (#3003) 2024-04-09 15:39:36 +08:00
Yeuoly
a2c068d949 feat: moonshot function call (#3227) 2024-04-09 15:30:09 +08:00
takatost
4ad3f2cdc2 fix: image text when retrieve chat histories (#3220) 2024-04-09 15:20:45 +08:00
legao
29918c498c fixed the issue of missing cleanup function in the AudioBtn component (#3133) 2024-04-09 15:10:58 +08:00
Joel
269432a5e6 fix: vision config doesn't enabled in llm (#3225) 2024-04-09 15:07:43 +08:00
takatost
a33b774314 fix: latest image tag not push in GitHub action (#3213) 2024-04-09 14:35:39 +08:00
Yeuoly
cc5ccaaca1 fix: incomplete response (#3215) 2024-04-09 14:35:25 +08:00
Jyong
33ea689861 fix detached instance error in keyword index create thread and fix question classifier node out of index error (#3219) 2024-04-09 14:34:51 +08:00
Chenhe Gu
581836b716 Update README.md (#3212) 2024-04-09 14:34:42 +08:00
Bowen Liang
0516b78d6f fix: index number in api/README (#3214) 2024-04-09 13:59:26 +08:00
Jyong
84d7cbf916 fix economy index search in workflow (#3205) 2024-04-09 13:20:51 +08:00
Chenhe Gu
f514fd2182 Update README.md (#3206) 2024-04-09 12:30:44 +08:00
zxhlyh
86707928d4 fix: node connect self (#3194) 2024-04-09 12:24:41 +08:00
Eric Wang
3c3fb3cd3f fix(code_executor): surrogates not allowed error in jinja2 template (#3191) 2024-04-09 12:21:03 +08:00
Yeuoly
337899a03d Fix/code transform result (#3203) 2024-04-09 12:20:34 +08:00
Jat
bae0c071cd Fix: remove unavailable return_preamble parameter in cohere (#3201)
Signed-off-by: Jat <jat@sinosky.org>
2024-04-09 12:11:53 +08:00
Joel
19cb3c7871 fix: sometimes chosed old selected knowledge may overwirte the new knowledge (#3199) 2024-04-09 11:46:59 +08:00
Jyong
2e4dec365d Compatible with unique index conflicts (#3183) 2024-04-09 02:16:19 +08:00
Chenhe Gu
ca3e2e6cc0 Update README.md to include workflows (#3180) 2024-04-09 01:49:19 +08:00
Jyong
283979fc46 fix keyword index error when storage source is S3 (#3182) 2024-04-09 01:42:58 +08:00
takatost
a81c1ab6ae version to 0.6.0-fix1 (#3179) 2024-04-09 00:10:20 +08:00
KVOJJJin
48d4d55ecc Fix: features of agent-chat (#3178) 2024-04-08 23:53:59 +08:00
zxhlyh
b7691f5658 fix: prompt editor variable picker (#3177) 2024-04-08 23:53:09 +08:00
crazywoola
1382f10433 feat: translations (#3176) 2024-04-08 23:17:16 +08:00
KVOJJJin
d8db728c33 Fix: prompt of expert mode (#3168) 2024-04-08 21:36:27 +08:00
takatost
d2259f20cb fix: app export dsl not include desc (#3167) 2024-04-08 21:30:18 +08:00
takatost
9720d6b7a5 fix: metadata in generate npe issue (#3166) 2024-04-08 21:30:03 +08:00
215 changed files with 9280 additions and 2903 deletions

View File

@@ -46,7 +46,7 @@ jobs:
with:
images: ${{ env[matrix.image_name_env] }}
tags: |
- type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' && startsWith(github.ref, 'refs/tags/') }}
+ type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}

View File

@@ -36,7 +36,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
| Feature Type | Priority |
| ------------------------------------------------------------ | --------------- |
| High-Priority Features as being labeled by a team member | High Priority |
- | Popular feature requests from our [community feedback board](https://feedback.dify.ai/) | Medium Priority |
+ | Popular feature requests from our [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Medium Priority |
| Non-core features and minor enhancements | Low Priority |
| Valuable but not immediate | Future-Feature |

View File

@@ -34,7 +34,7 @@
| Feature Type | Priority |
| ------------------------------------------------------------ | --------------- |
| High-Priority Features as being labeled by a team member | High Priority |
- | Popular feature requests from our [community feedback board](https://feedback.dify.ai/) | Medium Priority |
+ | Popular feature requests from our [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Medium Priority |
| Non-core features and minor enhancements | Low Priority |
| Valuable but not immediate | Future-Feature |

README.md
View File

@@ -1,95 +1,176 @@
[![](./images/GitHub_README_cover.png)](https://dify.ai)
![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab)
<p align="center">
<a href="./README.md">English</a> |
<a href="./README_CN.md">简体中文</a> |
<a href="./README_JA.md">日本語</a> |
<a href="./README_ES.md">Español</a> |
<a href="./README_KL.md">Klingon</a> |
<a href="./README_FR.md">Français</a>
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
<a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·
<a href="https://docs.dify.ai">Documentation</a> ·
<a href="https://cal.com/guchenhe/30min">Commercial inquiry</a>
</p>
<p align="center">
<a href="https://dify.ai" target="_blank">
<img alt="Static Badge" src="https://img.shields.io/badge/AI-Dify?logo=AI&logoColor=%20%23f5f5f5&label=Dify&labelColor=%20%23155EEF&color=%23EAECF0"></a>
<img alt="Static Badge" src="https://img.shields.io/badge/Product-F04438"></a>
<a href="https://dify.ai/pricing" target="_blank">
<img alt="Static Badge" src="https://img.shields.io/badge/free-pricing?logo=free&color=%20%23155EEF&label=pricing&labelColor=%20%23528bff"></a>
<a href="https://discord.gg/FngNHpbcY7" target="_blank">
<img src="https://img.shields.io/discord/1082486657678311454?logo=discord"
<img src="https://img.shields.io/discord/1082486657678311454?logo=discord&labelColor=%20%235462eb&logoColor=%20%23f5f5f5&color=%20%235462eb"
alt="chat on Discord"></a>
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?style=social&logo=X"
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="follow on Twitter"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
<img alt="Commits last month" src="https://img.shields.io/github/commit-activity/m/langgenius/dify?labelColor=%20%2332b583&color=%20%2312b76a"></a>
<a href="https://github.com/langgenius/dify/" target="_blank">
<img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
<a href="https://github.com/langgenius/dify/discussions/" target="_blank">
<img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
</p>
<p align="center">
<a href="https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6" target="_blank">
📌 Check out Dify Premium on AWS and deploy it to your own AWS VPC with one-click.
</a>
<a href="./README.md"><img alt="Commits last month" src="https://img.shields.io/badge/English-d9d9d9"></a>
<a href="./README_CN.md"><img alt="Commits last month" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./README_JA.md"><img alt="Commits last month" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./README_ES.md"><img alt="Commits last month" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./README_KL.md"><img alt="Commits last month" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README_FR.md"><img alt="Commits last month" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
</p>
**Dify** is an open-source LLM app development platform. Dify's intuitive interface combines a RAG pipeline, AI workflow orchestration, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production.
#
![](./images/demo.png)
<p align="center">
<a href="https://trendshift.io/repositories/2152" target="_blank"><img src="https://trendshift.io/api/badge/repositories/2152" alt="langgenius%2Fdify | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</p>
Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:
</br> </br>
**1. Workflow**:
Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond.
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
## Using our Cloud Services
**2. Comprehensive model support**:
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama2, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
You can try out [Dify.AI Cloud](https://dify.ai) now. It provides all the capabilities of the self-deployed version, and includes 200 free requests to OpenAI GPT-3.5.
![providers-v3](https://github.com/langgenius/dify/assets/13230914/55fab860-d818-4c95-95a2-7ac39f6aea83)
### Looking to purchase via AWS?
Check out [Dify Premium on AWS](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click.
## Dify vs. LangChain vs. Assistants API
**3. Prompt IDE**:
Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app.
| Feature | Dify.AI | Assistants API | LangChain |
|---------|---------|----------------|-----------|
| **Programming Approach** | API-oriented | API-oriented | Python Code-oriented |
| **Ecosystem Strategy** | Open Source | Close Source | Open Source |
| **RAG Engine** | Supported | Supported | Not Supported |
| **Prompt IDE** | Included | Included | None |
| **Supported LLMs** | Rich Variety | OpenAI-only | Rich Variety |
| **Local Deployment** | Supported | Not Supported | Not Applicable |
**4. RAG Pipeline**:
Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats.
**5. Agent capabilities**:
You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DELL·E, Stable Diffusion and WolframAlpha.
**6. LLMOps**:
Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations.
**7. Backend-as-a-Service**:
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
## Feature Comparison
<table style="width: 100%;">
<tr>
<th align="center">Feature</th>
<th align="center">Dify.AI</th>
<th align="center">LangChain</th>
<th align="center">Flowise</th>
<th align="center">OpenAI Assistants API</th>
</tr>
<tr>
<td align="center">Programming Approach</td>
<td align="center">API + App-oriented</td>
<td align="center">Python Code</td>
<td align="center">App-oriented</td>
<td align="center">API-oriented</td>
</tr>
<tr>
<td align="center">Supported LLMs</td>
<td align="center">Rich Variety</td>
<td align="center">Rich Variety</td>
<td align="center">Rich Variety</td>
<td align="center">OpenAI-only</td>
</tr>
<tr>
<td align="center">RAG Engine</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">✅</td>
</tr>
<tr>
<td align="center">Agent</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">✅</td>
</tr>
<tr>
<td align="center">Workflow</td>
<td align="center">✅</td>
<td align="center">❌</td>
<td align="center">✅</td>
<td align="center">❌</td>
</tr>
<tr>
<td align="center">Observability</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">❌</td>
<td align="center">❌</td>
</tr>
<tr>
<td align="center">Enterprise Feature (SSO/Access control)</td>
<td align="center">✅</td>
<td align="center">❌</td>
<td align="center">❌</td>
<td align="center">❌</td>
</tr>
<tr>
<td align="center">Local Deployment</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">❌</td>
</tr>
</table>
## Using Dify
- **Cloud </br>**
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
- **Self-hosting Dify Community Edition</br>**
Quickly get Dify running in your environment with this [starter guide](#quick-start).
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
- **Dify for Enterprise / Organizations</br>**
We provide additional enterprise-centric features. [Schedule a meeting with us](https://cal.com/guchenhe/30min) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs. </br>
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
## Staying ahead
Star Dify on GitHub and be instantly notified of new releases.
![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4)
## Features
## Quick Start
> Before installing Dify, make sure your machine meets the following minimum system requirements:
>
>- CPU >= 2 Core
>- RAM >= 4GB
![](./images/models.png)
**1. LLM Support**: Integration with OpenAI's GPT family of models, or the open-source Llama2 family models. In fact, Dify supports mainstream commercial models and open-source models (locally deployed or based on MaaS).
**2. Prompt IDE**: Visual orchestration of applications and services based on LLMs with your team.
**3. RAG Engine**: Includes various RAG capabilities based on full-text indexing or vector database embeddings, allowing direct upload of PDFs, TXTs, and other text formats.
**4. AI Agent**: Based on Function Calling and ReAct, the Agent inference framework allows users to customize tools, what you see is what you get. Dify provides more than a dozen built-in tool calling capabilities, such as Google Search, DELL·E, Stable Diffusion, WolframAlpha, etc.
**5. Continuous Operations**: Monitor and analyze application logs and performance, continuously improving Prompts, datasets, or models using production data.
## Before You Start
**Star us on GitHub, and be instantly notified for new releases!**
![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f)
- [Website](https://dify.ai)
- [Docs](https://docs.dify.ai)
- [Deployment Docs](https://docs.dify.ai/getting-started/install-self-hosted)
- [FAQ](https://docs.dify.ai/getting-started/faq)
## Install the Community Edition
### System Requirements
Before installing Dify, make sure your machine meets the following minimum system requirements:
- CPU >= 2 Core
- RAM >= 4GB
### Quick Start
</br>
The easiest way to start the Dify server is to run our [docker-compose.yml](docker/docker-compose.yaml) file. Before running the installation command, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
@@ -98,61 +179,63 @@ cd docker
docker compose up -d
```
- After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization installation process.
+ After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.
#### Deploy with Helm Chart
> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)
[Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
## Next steps
If you need to customize the configuration, please refer to the comments in our [docker-compose.yml](docker/docker-compose.yaml) file and manually set the environment configuration. After making the changes, please run `docker-compose up -d` again. You can see the full list of environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) which allow Dify to be deployed on Kubernetes.
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
### Configuration
If you need to customize the configuration, please refer to the comments in our [docker-compose.yml](docker/docker-compose.yaml) file and manually set the environment configuration. After making the changes, please run `docker-compose up -d` again. You can see the full list of environment variables in our [docs](https://docs.dify.ai/getting-started/install-self-hosted/environments).
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)
## Contributing
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
At the same time, please consider supporting Dify by sharing it on social media and at events and conferences.
### Projects made by community
- [Chatbot Chrome Extension by @charli117](https://github.com/langgenius/chatbot-chrome-extension)
> We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
### Contributors
**Contributors**
<a href="https://github.com/langgenius/dify/graphs/contributors">
<img src="https://contrib.rocks/image?repo=langgenius/dify" />
</a>
### Translations
## Community & Contact
We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
## Community & Support
- * [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and checking out our feature roadmap.
+ * [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions.
* [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
- * [Email Support](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify). Best for: questions you have about using Dify.AI.
+ * [Email](mailto:support@dify.ai?subject=[GitHub]Questions%20About%20Dify). Best for: questions you have about using Dify.AI.
* [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
* [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
* [Business Contact](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry). Best for: business inquiries of licensing Dify.AI for commercial use.
### Direct Meetings
Or, schedule a meeting directly with a team member:
**Help us make Dify better. Reach out directly to us**.
<table>
<tr>
<th>Point of Contact</th>
<th>Purpose</th>
</tr>
<tr>
<td><a href='https://cal.com/guchenhe/15min' target='_blank'><img class="schedule-button" src='https://github.com/langgenius/dify/assets/13230914/9ebcd111-1205-4d71-83d5-948d70b809f5' alt='Git-Hub-README-Button-3x' style="width: 180px; height: auto; object-fit: contain;"/></a></td>
<td>Business enquiries & product feedback</td>
</tr>
<tr>
<td><a href='https://cal.com/pinkbanana' target='_blank'><img class="schedule-button" src='https://github.com/langgenius/dify/assets/13230914/d1edd00a-d7e4-4513-be6c-e57038e143fd' alt='Git-Hub-README-Button-2x' style="width: 180px; height: auto; object-fit: contain;"/></a></td>
<td>Contributions, issues & feature requests</td>
</tr>
</table>
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)
| Point of Contact | Purpose |
| :----------------------------------------------------------: | :----------------------------------------------------------: |
| <a href='https://cal.com/guchenhe/15min' target='_blank'><img src='https://i.postimg.cc/fWBqSmjP/Git-Hub-README-Button-3x.png' border='0' alt='Git-Hub-README-Button-3x' height="60" width="214"/></a> | Product design feedback, user experience discussions, feature planning and roadmaps. |
| <a href='https://cal.com/pinkbanana' target='_blank'><img src='https://i.postimg.cc/LsRTh87D/Git-Hub-README-Button-2x.png' border='0' alt='Git-Hub-README-Button-2x' height="60" width="225"/></a> | Technical support, issues, or feature requests |
## Security Disclosure

View File

@@ -21,6 +21,10 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>
<p align="center">
<a href="https://trendshift.io/repositories/2152" target="_blank"><img src="https://trendshift.io/api/badge/repositories/2152" alt="langgenius%2Fdify | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</p>
<p align="center">
<a href="https://mp.weixin.qq.com/s/TnyfIuH-tPi9o1KNjwVArw" target="_blank">
Dify 发布 AI Agent 能力:基于不同的大型语言模型构建 GPTs 和 Assistants

View File

@@ -115,6 +115,7 @@ docker compose up -d
Difyに貢献していただき、コードの提出、問題の報告、新しいアイデアの提供、またはDifyを基に作成した興味深く有用なAIアプリケーションの共有により、Difyをより良いものにするお手伝いを歓迎します。同時に、さまざまなイベント、会議、ソーシャルメディアでDifyを共有することも歓迎します。
- [Github Discussion](https://github.com/langgenius/dify/discussions). 👉:アプリを共有し、コミュニティとコミュニケーション。
- [GitHub Issues](https://github.com/langgenius/dify/issues)。最適な使用法Dify.AIの使用中に遭遇するバグやエラー、[貢献ガイド](CONTRIBUTING.md)を参照。
- [Email サポート](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify)。最適な使用法Dify.AIの使用に関する質問。
- [Discord](https://discord.gg/FngNHpbcY7)。最適な使用法:アプリケーションの共有とコミュニティとの交流。

View File

@@ -11,7 +11,8 @@ RUN apt-get update \
COPY requirements.txt /requirements.txt
- RUN pip install --prefix=/pkg -r requirements.txt
+ RUN --mount=type=cache,target=/root/.cache/pip \
+     pip install --prefix=/pkg -r requirements.txt
# production stage
FROM base AS production

View File

@@ -17,16 +17,16 @@
```bash
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
```
- 3.5 If you use Anaconda, create a new environment and activate it
+ 4. If you use Anaconda, create a new environment and activate it
```bash
conda create --name dify python=3.10
conda activate dify
```
- 4. Install dependencies
+ 5. Install dependencies
```bash
pip install -r requirements.txt
```
- 5. Run migrate
+ 6. Run migrate
Before the first launch, migrate the database to the latest version.
@@ -47,9 +47,11 @@
pip install -r requirements.txt --upgrade --force-reinstall
```
- 6. Start backend:
+ 7. Start backend:
```bash
flask run --host 0.0.0.0 --port=5001 --debug
```
- 7. Setup your application by visiting http://localhost:5001/console/api/setup or other apis...
- 8. If you need to debug local async processing, you can run `celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`, celery can do dataset importing and other async tasks.
+ 8. Setup your application by visiting http://localhost:5001/console/api/setup or other apis...
+ 9. If you need to debug local async processing, please start the worker service by running
+    `celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`.
+    The started celery app handles the async tasks, e.g. dataset importing and documents indexing.

View File

@@ -42,7 +42,7 @@ DEFAULTS = {
'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
- 'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
+ 'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-turbo-2024-04-09,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
@@ -67,6 +67,7 @@ DEFAULTS = {
'CODE_EXECUTION_ENDPOINT': '',
'CODE_EXECUTION_API_KEY': '',
'TOOL_ICON_CACHE_MAX_AGE': 3600,
+ 'MILVUS_DATABASE': 'default',
'KEYWORD_DATA_SOURCE_TYPE': 'database',
}
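
The key added above extends the DEFAULTS table that backs Dify's environment lookups. As a minimal sketch of how such a lookup presumably resolves — the real `get_env` helper in api/config.py may coerce types or handle missing keys differently:

```python
import os

# Assumed shape of the DEFAULTS-backed lookup (illustrative, not Dify's code).
DEFAULTS = {
    'MILVUS_DATABASE': 'default',
    'KEYWORD_DATA_SOURCE_TYPE': 'database',
}

def get_env(key: str) -> str:
    # Environment variables take precedence; otherwise fall back to DEFAULTS.
    return os.environ.get(key, DEFAULTS.get(key, ''))
```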
@@ -98,7 +99,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.6.0"
self.CURRENT_VERSION = "0.6.2"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -212,6 +213,7 @@ class Config:
self.MILVUS_USER = get_env('MILVUS_USER')
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
+ self.MILVUS_DATABASE = get_env('MILVUS_DATABASE')
# weaviate settings
self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
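
For context on the new MILVUS_DATABASE setting: pymilvus (2.3+) accepts a `db_name` when connecting, which is presumably where this value ends up. A hedged sketch with placeholder connection values — Dify's actual vector-store wiring may differ:

```python
from pymilvus import connections

# Placeholder values standing in for the MILVUS_* settings above.
connections.connect(
    alias='default',
    host='127.0.0.1',    # MILVUS_HOST (assumed setting name)
    port='19530',        # MILVUS_PORT (assumed setting name)
    user='root',         # MILVUS_USER
    password='Milvus',   # MILVUS_PASSWORD
    db_name='default',   # MILVUS_DATABASE, new in this release
)
```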

View File

@@ -6,6 +6,7 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
+ from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
@@ -39,8 +40,8 @@ class ConversationListApi(InstalledAppResource):
user=current_user,
last_id=args['last_id'],
limit=args['limit'],
+ invoke_from=InvokeFrom.EXPLORE,
pinned=pinned,
- exclude_debug_conversation=True
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

View File

@@ -1,14 +1,11 @@
- import json
- from flask import current_app
- from flask_restful import fields, marshal_with, Resource
+ from flask_restful import Resource, fields, marshal_with
from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
- from extensions.ext_database import db
- from models.model import App, AppModelConfig, AppMode
- from models.tools import ApiToolProvider
+ from models.model import App, AppMode
from services.app_service import AppService
@@ -92,6 +89,16 @@ class AppMetaApi(Resource):
"""Get app meta"""
return AppService().get_app_meta(app_model)
+ class AppInfoApi(Resource):
+     @validate_app_token
+     def get(self, app_model: App):
+         """Get app infomation"""
+         return {
+             'name': app_model.name,
+             'description': app_model.description
+         }

api.add_resource(AppParameterApi, '/parameters')
api.add_resource(AppMetaApi, '/meta')
+ api.add_resource(AppInfoApi, '/info')
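
The new `/info` route returns the app's name and description over the Service API (per the `'name' field return` commit above). A hypothetical client call — base URL and token are placeholders for your own deployment:

```python
import requests

# Placeholder base URL and app token; adjust for your deployment.
resp = requests.get(
    'http://localhost/v1/info',
    headers={'Authorization': 'Bearer app-your-token'},
)
resp.raise_for_status()
print(resp.json())  # expected shape: {'name': ..., 'description': ...}
```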

View File

@@ -6,6 +6,7 @@ import services
from controllers.service_api import api
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
+ from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
@@ -27,7 +28,13 @@ class ConversationApi(Resource):
args = parser.parse_args()
try:
- return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
+ return ConversationService.pagination_by_last_id(
+     app_model=app_model,
+     user=end_user,
+     last_id=args['last_id'],
+     limit=args['limit'],
+     invoke_from=InvokeFrom.SERVICE_API
+ )
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

View File

@@ -5,6 +5,7 @@ from werkzeug.exceptions import NotFound
from controllers.web import api
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
+ from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
@@ -37,7 +38,8 @@ class ConversationListApi(WebApiResource):
user=end_user,
last_id=args['last_id'],
limit=args['limit'],
- pinned=pinned
+ invoke_from=InvokeFrom.WEB_APP,
+ pinned=pinned,
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
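
Across the explore, service-api, and web controllers above, `pagination_by_last_id` now receives the caller's origin instead of an `exclude_debug_conversation` flag. A hedged sketch of the idea; the real `InvokeFrom` enum and the filtering inside ConversationService may differ:

```python
from enum import Enum

class InvokeFrom(Enum):
    # Assumed members, mirroring the call sites in this diff.
    SERVICE_API = 'service-api'
    WEB_APP = 'web-app'
    EXPLORE = 'explore'
    DEBUGGER = 'debugger'

def excludes_debug_conversations(invoke_from: InvokeFrom) -> bool:
    # Sketch: every non-debugger origin hides debugger conversations,
    # which is what the removed exclude_debug_conversation=True implied.
    return invoke_from != InvokeFrom.DEBUGGER
```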

View File

@@ -5,6 +5,7 @@ from datetime import datetime
from typing import Optional, Union, cast
from core.agent.entities import AgentEntity, AgentToolEntity
+ from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
@@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+ from core.file.message_file_parser import MessageFileParser
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -22,6 +24,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
+ TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
@@ -37,7 +40,7 @@ from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
- from models.model import Message, MessageAgentThought
+ from models.model import Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner):
def __init__(self, tenant_id: str,
application_generate_entity: AgentChatAppGenerateEntity,
+ conversation: Conversation,
app_config: AgentChatAppConfig,
model_config: ModelConfigWithCredentialsEntity,
config: AgentEntity,
@@ -72,6 +76,7 @@ class BaseAgentRunner(AppRunner):
"""
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
+ self.conversation = conversation
self.app_config = app_config
self.model_config = model_config
self.config = config
@@ -118,6 +123,12 @@ class BaseAgentRunner(AppRunner):
else:
self.stream_tool_call = False
+         # check if model supports vision
+         if model_schema and ModelFeature.VISION in (model_schema.features or []):
+             self.files = application_generate_entity.files
+         else:
+             self.files = []
def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-> AgentChatAppGenerateEntity:
"""
@@ -227,6 +238,34 @@ class BaseAgentRunner(AppRunner):
return prompt_tool
+     def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
+         """
+         Init tools
+         """
+         tool_instances = {}
+         prompt_messages_tools = []
+
+         for tool in self.app_config.agent.tools if self.app_config.agent else []:
+             try:
+                 prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
+             except Exception:
+                 # api tool may be deleted
+                 continue
+             # save tool entity
+             tool_instances[tool.tool_name] = tool_entity
+             # save prompt tool
+             prompt_messages_tools.append(prompt_tool)
+
+         # convert dataset tools into ModelRuntime Tool format
+         for dataset_tool in self.dataset_tools:
+             prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
+             # save prompt tool
+             prompt_messages_tools.append(prompt_tool)
+             # save tool entity
+             tool_instances[dataset_tool.identity.name] = dataset_tool
+
+         return tool_instances, prompt_messages_tools
def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
"""
update prompt message tool
@@ -314,7 +353,7 @@ class BaseAgentRunner(AppRunner):
tool_name: str,
tool_input: Union[str, dict],
thought: str,
- observation: Union[str, str],
+ observation: Union[str, dict],
tool_invoke_meta: Union[str, dict],
answer: str,
messages_ids: list[str],
@@ -412,15 +451,19 @@ class BaseAgentRunner(AppRunner):
"""
result = []
# check if there is a system message in the beginning of the conversation
- if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
-     result.append(prompt_messages[0])
+ for prompt_message in prompt_messages:
+     if isinstance(prompt_message, SystemPromptMessage):
+         result.append(prompt_message)
messages: list[Message] = db.session.query(Message).filter(
Message.conversation_id == self.message.conversation_id,
).order_by(Message.created_at.asc()).all()
for message in messages:
- result.append(UserPromptMessage(content=message.query))
+ if message.id == self.message.id:
+     continue
+ result.append(self.organize_agent_user_prompt(message))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
@@ -471,3 +514,32 @@ class BaseAgentRunner(AppRunner):
db.session.close()
return result
+     def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
+         message_file_parser = MessageFileParser(
+             tenant_id=self.tenant_id,
+             app_id=self.app_config.app_id,
+         )
+
+         files = message.message_files
+         if files:
+             file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
+
+             if file_extra_config:
+                 file_objs = message_file_parser.transform_message_files(
+                     files,
+                     file_extra_config
+                 )
+             else:
+                 file_objs = []
+
+             if not file_objs:
+                 return UserPromptMessage(content=message.query)
+             else:
+                 prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+                 for file_obj in file_objs:
+                     prompt_message_contents.append(file_obj.prompt_message_content)
+
+                 return UserPromptMessage(content=prompt_message_contents)
+         else:
+             return UserPromptMessage(content=message.query)

View File

@@ -1,33 +1,36 @@
import json
import re
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Literal, Union
from typing import Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
- from models.model import Conversation, Message
+ from models.model import Message
- class CotAgentRunner(BaseAgentRunner):
+ class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ['wenxin']
_historic_prompt_messages: list[PromptMessage] = None
_agent_scratchpad: list[AgentScratchpadUnit] = None
_instruction: str = None
_query: str = None
_prompt_messages_tools: list[PromptMessage] = None
- def run(self, conversation: Conversation,
-         message: Message,
+ def run(self, message: Message,
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
@@ -36,9 +39,7 @@ class CotAgentRunner(BaseAgentRunner):
"""
app_generate_entity = self.application_generate_entity
self._repack_app_generate_entity(app_generate_entity)
- agent_scratchpad: list[AgentScratchpadUnit] = []
- self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
+ self._init_react_state(query)
# check model mode
if 'Observation' not in app_generate_entity.model_config.stop:
@@ -47,38 +48,19 @@ class CotAgentRunner(BaseAgentRunner):
app_config = self.app_config
- # override inputs
+ # init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template
- instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
+ self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
prompt_messages = self.history_prompt_messages
# convert tools into ModelRuntime Tool format
prompt_messages_tools: list[PromptMessageTool] = []
tool_instances = {}
for tool in app_config.agent.tools if app_config.agent else []:
try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception:
# api tool may be deleted
continue
# save tool entity
tool_instances[tool.tool_name] = tool_entity
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# convert dataset tools into ModelRuntime Tool format
for dataset_tool in self.dataset_tools:
prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# save tool entity
tool_instances[dataset_tool.identity.name] = dataset_tool
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
prompt_messages = self._organize_prompt_messages()
function_call_state = True
llm_usage = {
'usage': None
@@ -103,7 +85,7 @@ class CotAgentRunner(BaseAgentRunner):
if iteration_step == max_iteration_steps:
# the last iteration, remove all tools
- prompt_messages_tools = []
+ self._prompt_messages_tools = []
message_file_ids = []
@@ -120,18 +102,8 @@ class CotAgentRunner(BaseAgentRunner):
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt messages
prompt_messages = self._organize_cot_prompt_messages(
mode=app_generate_entity.model_config.mode,
prompt_messages=prompt_messages,
tools=prompt_messages_tools,
agent_scratchpad=agent_scratchpad,
agent_prompt_message=app_config.agent.prompt,
instruction=instruction,
input=query
)
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
@@ -149,7 +121,7 @@ class CotAgentRunner(BaseAgentRunner):
raise ValueError("failed to invoke llm")
usage_dict = {}
- react_chunks = self._handle_stream_react(chunks, usage_dict)
+ react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks)
scratchpad = AgentScratchpadUnit(
agent_response='',
thought='',
@@ -165,30 +137,12 @@ class CotAgentRunner(BaseAgentRunner):
), PublishFrom.APPLICATION_MANAGER)
for chunk in react_chunks:
if isinstance(chunk, dict):
scratchpad.agent_response += json.dumps(chunk)
try:
if scratchpad.action:
raise Exception("")
scratchpad.action_str = json.dumps(chunk)
scratchpad.action = AgentScratchpadUnit.Action(
action_name=chunk['action'],
action_input=chunk['action_input']
)
except:
scratchpad.thought += json.dumps(chunk)
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint='',
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=json.dumps(chunk, ensure_ascii=False) # if ensure_ascii=True, the text in webui maybe garbled text
),
usage=None
)
)
if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk
# detect action
scratchpad.agent_response += json.dumps(chunk.dict())
scratchpad.action_str = json.dumps(chunk.dict())
scratchpad.action = action
else:
scratchpad.agent_response += chunk
scratchpad.thought += chunk
@@ -206,27 +160,29 @@ class CotAgentRunner(BaseAgentRunner):
)
scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
- agent_scratchpad.append(scratchpad)
+ self._agent_scratchpad.append(scratchpad)
# get llm usage
if 'usage' in usage_dict:
increase_usage(llm_usage, usage_dict['usage'])
else:
usage_dict['usage'] = LLMUsage.empty_usage()
- self.save_agent_thought(agent_thought=agent_thought,
-     tool_name=scratchpad.action.action_name if scratchpad.action else '',
-     tool_input={
-         scratchpad.action.action_name: scratchpad.action.action_input
-     } if scratchpad.action else '',
-     tool_invoke_meta={},
-     thought=scratchpad.thought,
-     observation='',
-     answer=scratchpad.agent_response,
-     messages_ids=[],
-     llm_usage=usage_dict['usage'])
+ self.save_agent_thought(
+     agent_thought=agent_thought,
+     tool_name=scratchpad.action.action_name if scratchpad.action else '',
+     tool_input={
+         scratchpad.action.action_name: scratchpad.action.action_input
+     } if scratchpad.action else {},
+     tool_invoke_meta={},
+     thought=scratchpad.thought,
+     observation='',
+     answer=scratchpad.agent_response,
+     messages_ids=[],
+     llm_usage=usage_dict['usage']
+ )
- if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
+ if not scratchpad.is_final():
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
@@ -238,106 +194,43 @@ class CotAgentRunner(BaseAgentRunner):
if scratchpad.action.action_name.lower() == "final answer":
# action is final answer, return final answer directly
try:
- final_answer = scratchpad.action.action_input if \
-     isinstance(scratchpad.action.action_input, str) else \
-     json.dumps(scratchpad.action.action_input)
+ if isinstance(scratchpad.action.action_input, dict):
+     final_answer = json.dumps(scratchpad.action.action_input)
+ elif isinstance(scratchpad.action.action_input, str):
+     final_answer = scratchpad.action.action_input
+ else:
+     final_answer = f'{scratchpad.action.action_input}'
except json.JSONDecodeError:
final_answer = f'{scratchpad.action.action_input}'
else:
function_call_state = True
# action is tool call, invoke tool
tool_call_name = scratchpad.action.action_name
tool_call_args = scratchpad.action.action_input
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
answer = f"there is not a tool named {tool_call_name}"
self.save_agent_thought(
agent_thought=agent_thought,
tool_name='',
tool_input='',
tool_invoke_meta=ToolInvokeMeta.error_instance(
f"there is not a tool named {tool_call_name}"
).to_dict(),
thought=None,
observation={
tool_call_name: answer
},
answer=answer,
messages_ids=[]
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
else:
if isinstance(tool_call_args, str):
try:
tool_call_args = json.loads(tool_call_args)
except json.JSONDecodeError:
pass
tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
action=scratchpad.action,
tool_instances=tool_instances,
message_file_ids=message_file_ids
)
scratchpad.observation = tool_invoke_response
scratchpad.agent_response = tool_invoke_response
# invoke tool
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=self.message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback
)
# publish files
for message_file, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought,
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta=tool_invoke_meta.to_dict(),
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
llm_usage=usage_dict['usage']
)
# publish message file
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file.id
), PublishFrom.APPLICATION_MANAGER)
# add message file ids
message_file_ids.append(message_file.id)
# publish files
for message_file, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name,
value=message_file.id,
name=save_as)
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file.id
), PublishFrom.APPLICATION_MANAGER)
message_file_ids = [message_file.id for message_file, _ in message_files]
observation = tool_invoke_response
# save scratchpad
scratchpad.observation = observation
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=tool_call_name,
tool_input={
tool_call_name: tool_call_args
},
tool_invoke_meta={
tool_call_name: tool_invoke_meta.to_dict()
},
thought=None,
observation={
tool_call_name: observation
},
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt tool message
- for prompt_tool in prompt_messages_tools:
+ for prompt_tool in self._prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1
@@ -379,96 +272,63 @@ class CotAgentRunner(BaseAgentRunner):
system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER)
def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
-> Generator[Union[str, dict], None, None]:
def parse_json(json_str):
def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool],
message_file_ids: list[str]) -> tuple[str, ToolInvokeMeta]:
"""
handle invoke action
:param action: action
:param tool_instances: tool instances
:return: observation, meta
"""
# action is tool call, invoke tool
tool_call_name = action.action_name
tool_call_args = action.action_input
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
answer = f"there is not a tool named {tool_call_name}"
return answer, ToolInvokeMeta.error_instance(answer)
if isinstance(tool_call_args, str):
try:
return json.loads(json_str.strip())
except:
return json_str
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
if not code_blocks:
return
for block in code_blocks:
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
yield parse_json(json_text)
code_block_cache = ''
code_block_delimiter_count = 0
in_code_block = False
json_cache = ''
json_quote_count = 0
in_json = False
got_json = False
for response in llm_response:
response = response.delta.message.content
if not isinstance(response, str):
continue
tool_call_args = json.loads(tool_call_args)
except json.JSONDecodeError:
pass
# stream
index = 0
while index < len(response):
steps = 1
delta = response[index:index+steps]
if delta == '`':
code_block_cache += delta
code_block_delimiter_count += 1
else:
if not in_code_block:
if code_block_delimiter_count > 0:
yield code_block_cache
code_block_cache = ''
else:
code_block_cache += delta
code_block_delimiter_count = 0
# invoke tool
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=self.message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback
)
if code_block_delimiter_count == 3:
if in_code_block:
yield from extra_json_from_code_block(code_block_cache)
code_block_cache = ''
in_code_block = not in_code_block
code_block_delimiter_count = 0
# publish files
for message_file, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
if not in_code_block:
# handle single json
if delta == '{':
json_quote_count += 1
in_json = True
json_cache += delta
elif delta == '}':
json_cache += delta
if json_quote_count > 0:
json_quote_count -= 1
if json_quote_count == 0:
in_json = False
got_json = True
index += steps
continue
else:
if in_json:
json_cache += delta
# publish message file
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file.id
), PublishFrom.APPLICATION_MANAGER)
# add message file ids
message_file_ids.append(message_file.id)
if got_json:
got_json = False
yield parse_json(json_cache)
json_cache = ''
json_quote_count = 0
in_json = False
if not in_code_block and not in_json:
yield delta.replace('`', '')
return tool_invoke_response, tool_invoke_meta
index += steps
if code_block_cache:
yield code_block_cache
if json_cache:
yield parse_json(json_cache)
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
"""
convert dict to action
"""
return AgentScratchpadUnit.Action(
action_name=action['action'],
action_input=action['action_input']
)
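
For orientation: the `CotAgentOutputParser.handle_react_stream_output` call that replaces `_handle_stream_react` (see the `react_chunks` hunk above) yields plain text chunks or parsed Action units extracted from fenced JSON blobs. A standalone, heavily simplified sketch of that contract, operating on a whole string rather than a stream — not the actual implementation:

```python
import json
import re
from dataclasses import dataclass
from typing import Iterator, Union

@dataclass
class Action:
    action_name: str
    action_input: Union[str, dict]

def handle_react_output(text: str) -> Iterator[Union[str, Action]]:
    # Simplified: the real parser consumes streamed chunks incrementally.
    pos = 0
    for m in re.finditer(r'`{3}(?:json)?\s*(.*?)`{3}', text, re.DOTALL):
        if m.start() > pos:
            yield text[pos:m.start()]  # thought text outside code blocks
        try:
            blob = json.loads(m.group(1).strip())
            yield Action(blob['action'], blob.get('action_input', ''))
        except (json.JSONDecodeError, KeyError, TypeError):
            yield m.group(1)  # non-JSON code block passes through as text
        pos = m.end()
    if pos < len(text):
        yield text[pos:]
```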
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
"""
@@ -482,15 +342,46 @@ class CotAgentRunner(BaseAgentRunner):
return instruction
- def _init_agent_scratchpad(self,
-     agent_scratchpad: list[AgentScratchpadUnit],
-     messages: list[PromptMessage]
- ) -> list[AgentScratchpadUnit]:
+ def _init_react_state(self, query) -> None:
"""
init agent scratchpad
"""
self._query = query
self._agent_scratchpad = []
self._historic_prompt_messages = self._organize_historic_prompt_messages()
@abstractmethod
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
organize prompt messages
"""
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
format assistant message
"""
message = ''
for scratchpad in agent_scratchpad:
if scratchpad.is_final():
message += f"Final Answer: {scratchpad.agent_response}"
else:
message += f"Thought: {scratchpad.thought}\n\n"
if scratchpad.action_str:
message += f"Action: {scratchpad.action_str}\n\n"
if scratchpad.observation:
message += f"Observation: {scratchpad.observation}\n\n"
return message
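A minimal sketch of the ReAct transcript `_format_assistant_message` builds, using a simplified stand-in for `AgentScratchpadUnit` (the dataclass below is illustrative only, not the real entity):
```
from dataclasses import dataclass
from typing import Optional

@dataclass
class ScratchpadUnit:
    thought: Optional[str] = None
    action_str: Optional[str] = None
    observation: Optional[str] = None
    agent_response: Optional[str] = None

    def is_final(self) -> bool:
        # Simplified: final once the agent has produced its answer.
        return self.agent_response is not None

def format_assistant_message(units: list[ScratchpadUnit]) -> str:
    message = ''
    for unit in units:
        if unit.is_final():
            message += f"Final Answer: {unit.agent_response}"
        else:
            message += f"Thought: {unit.thought}\n\n"
            if unit.action_str:
                message += f"Action: {unit.action_str}\n\n"
            if unit.observation:
                message += f"Observation: {unit.observation}\n\n"
    return message

units = [
    ScratchpadUnit(thought="I should look this up.",
                   action_str='{"action": "search", "action_input": "dify"}',
                   observation="Dify is an LLM app platform."),
    ScratchpadUnit(agent_response="Dify is an LLM app platform."),
]
print(format_assistant_message(units))
```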
def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
"""
organize historic prompt messages
"""
result: list[PromptMessage] = []
scratchpad: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit = None
for message in messages:
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
@@ -505,186 +396,29 @@ class CotAgentRunner(BaseAgentRunner):
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments)
)
current_scratchpad.action_str = json.dumps(
current_scratchpad.action.to_dict()
)
except:
pass
agent_scratchpad.append(current_scratchpad)
scratchpad.append(current_scratchpad)
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
current_scratchpad.observation = message.content
elif isinstance(message, UserPromptMessage):
result.append(message)
if scratchpad:
result.append(AssistantPromptMessage(
content=self._format_assistant_message(scratchpad)
))
scratchpad = []
if scratchpad:
result.append(AssistantPromptMessage(
content=self._format_assistant_message(scratchpad)
))
return agent_scratchpad
def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
agent_prompt_message: AgentPromptEntity,
):
"""
check chain of thought prompt messages, a standard prompt message is like:
Respond to the human as helpfully and accurately as possible.
{{instruction}}
You have access to the following tools:
{{tools}}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid action values: "Final Answer" or {{tool_names}}
Provide only ONE action per $JSON_BLOB, as shown:
```
{
"action": $TOOL_NAME,
"action_input": $ACTION_INPUT
}
```
"""
# parse agent prompt message
first_prompt = agent_prompt_message.first_prompt
next_iteration = agent_prompt_message.next_iteration
if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
raise ValueError("first_prompt or next_iteration is required in CoT agent mode")
# check instruction, tools, and tool_names slots
if not first_prompt.find("{{instruction}}") >= 0:
raise ValueError("{{instruction}} is required in first_prompt")
if not first_prompt.find("{{tools}}") >= 0:
raise ValueError("{{tools}} is required in first_prompt")
if not first_prompt.find("{{tool_names}}") >= 0:
raise ValueError("{{tool_names}} is required in first_prompt")
if mode == "completion":
if not first_prompt.find("{{query}}") >= 0:
raise ValueError("{{query}} is required in first_prompt")
if not first_prompt.find("{{agent_scratchpad}}") >= 0:
raise ValueError("{{agent_scratchpad}} is required in first_prompt")
if mode == "completion":
if not next_iteration.find("{{observation}}") >= 0:
raise ValueError("{{observation}} is required in next_iteration")
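The removed `_check_cot_prompt_messages` validates the required `{{...}}` slots with `str.find`; the same check reads more directly with `in`. A minimal sketch (the slot list and mode handling mirror the checks above):
```
REQUIRED_SLOTS = ("{{instruction}}", "{{tools}}", "{{tool_names}}")

def check_slots(first_prompt: str, mode: str) -> None:
    # Three base slots, plus two more in completion mode.
    slots = list(REQUIRED_SLOTS)
    if mode == "completion":
        slots += ["{{query}}", "{{agent_scratchpad}}"]
    for slot in slots:
        if slot not in first_prompt:
            raise ValueError(f"{slot} is required in first_prompt")

check_slots("{{instruction}} Tools: {{tools}} Names: {{tool_names}}", mode="chat")
```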
def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
convert agent scratchpad list to str
"""
next_iteration = self.app_config.agent.prompt.next_iteration
result = ''
for scratchpad in agent_scratchpad:
result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \
next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available')
return result
def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
agent_scratchpad: list[AgentScratchpadUnit],
agent_prompt_message: AgentPromptEntity,
instruction: str,
input: str,
) -> list[PromptMessage]:
"""
organize chain of thought prompt messages, a standard prompt message is like:
Respond to the human as helpfully and accurately as possible.
{{instruction}}
You have access to the following tools:
{{tools}}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid action values: "Final Answer" or {{tool_names}}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $ACTION_INPUT
}}}}
```
"""
self._check_cot_prompt_messages(mode, agent_prompt_message)
# parse agent prompt message
first_prompt = agent_prompt_message.first_prompt
# parse tools
tools_str = self._jsonify_tool_prompt_messages(tools)
# parse tools name
tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'
# get system message
system_message = first_prompt.replace("{{instruction}}", instruction) \
.replace("{{tools}}", tools_str) \
.replace("{{tool_names}}", tool_names)
# organize prompt messages
if mode == "chat":
# override system message
overridden = False
prompt_messages = prompt_messages.copy()
for prompt_message in prompt_messages:
if isinstance(prompt_message, SystemPromptMessage):
prompt_message.content = system_message
overridden = True
break
# convert tool prompt messages to user prompt messages
for idx, prompt_message in enumerate(prompt_messages):
if isinstance(prompt_message, ToolPromptMessage):
prompt_messages[idx] = UserPromptMessage(
content=prompt_message.content
)
if not overridden:
prompt_messages.insert(0, SystemPromptMessage(
content=system_message,
))
# add assistant message
if len(agent_scratchpad) > 0 and not self._is_first_iteration:
prompt_messages.append(AssistantPromptMessage(
content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''),
))
# add user message
if len(agent_scratchpad) > 0 and not self._is_first_iteration:
prompt_messages.append(UserPromptMessage(
content=(agent_scratchpad[-1].observation or 'It seems that no response is available'),
))
self._is_first_iteration = False
return prompt_messages
elif mode == "completion":
# parse agent scratchpad
agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
self._is_first_iteration = False
# parse prompt messages
return [UserPromptMessage(
content=first_prompt.replace("{{instruction}}", instruction)
.replace("{{tools}}", tools_str)
.replace("{{tool_names}}", tool_names)
.replace("{{query}}", input)
.replace("{{agent_scratchpad}}", agent_scratchpad_str),
)]
else:
raise ValueError(f"mode {mode} is not supported")
def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
"""
jsonify tool prompt messages
"""
tools = jsonable_encoder(tools)
try:
return json.dumps(tools, ensure_ascii=False)
except json.JSONDecodeError:
return json.dumps(tools)
return result

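The removed `_jsonify_tool_prompt_messages` fills the `{{tools}}` slot with serialized tool schemas. A hedged sketch of the same idea over plain dicts; note that `json.dumps` raises `TypeError` rather than `JSONDecodeError` on unserializable input, so that is what a fallback needs to catch:
```
import json

def jsonify_tools(tools: list[dict]) -> str:
    try:
        # Keep non-ASCII text (e.g. zh_Hans labels) readable in the prompt.
        return json.dumps(tools, ensure_ascii=False)
    except TypeError:
        # Unserializable objects raise TypeError; stringify them as a fallback.
        return json.dumps(tools, default=str)

print(jsonify_tools([{"name": "search", "description": "网页搜索"}]))
```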
View File

@@ -0,0 +1,71 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
class CotChatAgentRunner(CotAgentRunner):
def _organize_system_prompt(self) -> SystemPromptMessage:
"""
Organize system prompt
"""
prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt
system_prompt = first_prompt \
.replace("{{instruction}}", self._instruction) \
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
return SystemPromptMessage(content=system_prompt)
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize prompt messages
"""
# organize system prompt
system_message = self._organize_system_prompt()
# organize historic prompt messages
historic_messages = self._historic_prompt_messages
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
if not agent_scratchpad:
assistant_messages = []
else:
assistant_message = AssistantPromptMessage(content='')
for unit in agent_scratchpad:
if unit.is_final():
assistant_message.content += f"Final Answer: {unit.agent_response}"
else:
assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n"
if unit.observation:
assistant_message.content += f"Observation: {unit.observation}\n\n"
assistant_messages = [assistant_message]
# query messages
query_messages = UserPromptMessage(content=self._query)
if assistant_messages:
messages = [
system_message,
*historic_messages,
query_messages,
*assistant_messages,
UserPromptMessage(content='continue')
]
else:
messages = [system_message, *historic_messages, query_messages]
# join all messages
return messages

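For orientation, a sketch of the message ordering `CotChatAgentRunner._organize_prompt_messages` produces, using plain (role, content) tuples in place of PromptMessage objects; the trailing "continue" user turn appears only once scratchpad content exists:
```
def organize(system: str, history: list[tuple[str, str]], query: str,
             scratchpad_text: str | None) -> list[tuple[str, str]]:
    messages = [("system", system), *history, ("user", query)]
    if scratchpad_text:
        # Mid-reasoning: replay the agent's scratchpad, then ask it to continue.
        messages += [("assistant", scratchpad_text), ("user", "continue")]
    return messages

print(organize("You are a ReAct agent.", [], "What is Dify?",
               "Thought: I should search.\n\n"))
```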
View File

@@ -0,0 +1,69 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
from core.model_runtime.utils.encoders import jsonable_encoder
class CotCompletionAgentRunner(CotAgentRunner):
def _organize_instruction_prompt(self) -> str:
"""
Organize instruction prompt
"""
prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt
system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
return system_prompt
def _organize_historic_prompt(self) -> str:
"""
Organize historic prompt
"""
historic_prompt_messages = self._historic_prompt_messages
historic_prompt = ""
for message in historic_prompt_messages:
if isinstance(message, UserPromptMessage):
historic_prompt += f"Question: {message.content}\n\n"
elif isinstance(message, AssistantPromptMessage):
historic_prompt += message.content + "\n\n"
return historic_prompt
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize prompt messages
"""
# organize system prompt
system_prompt = self._organize_instruction_prompt()
# organize historic prompt messages
historic_prompt = self._organize_historic_prompt()
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
assistant_prompt = ''
for unit in agent_scratchpad:
if unit.is_final():
assistant_prompt += f"Final Answer: {unit.agent_response}"
else:
assistant_prompt += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_prompt += f"Action: {unit.action_str}\n\n"
if unit.observation:
assistant_prompt += f"Observation: {unit.observation}\n\n"
# query messages
query_prompt = f"Question: {self._query}"
# join all messages
prompt = system_prompt \
.replace("{{historic_messages}}", historic_prompt) \
.replace("{{agent_scratchpad}}", assistant_prompt) \
.replace("{{query}}", query_prompt)
return [UserPromptMessage(content=prompt)]

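Completion mode folds everything into a single prompt string via `str.replace`. A runnable sketch with a made-up `first_prompt` template; the placeholder names match the ones the runner fills:
```
# Hypothetical first_prompt template for illustration only.
template = ("{{instruction}}\n\nHistory:\n{{historic_messages}}\n"
            "{{agent_scratchpad}}\n{{query}}")

prompt = (template
          .replace("{{instruction}}", "Answer precisely.")
          .replace("{{historic_messages}}", "Question: hi\n\nHello!\n\n")
          .replace("{{agent_scratchpad}}", "Thought: greet back.\n\n")
          .replace("{{query}}", "Question: how are you?"))
print(prompt)
```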
View File

@@ -34,12 +34,29 @@ class AgentScratchpadUnit(BaseModel):
action_name: str
action_input: Union[dict, str]
def to_dict(self) -> dict:
"""
Convert to dictionary.
"""
return {
'action': self.action_name,
'action_input': self.action_input,
}
agent_response: Optional[str] = None
thought: Optional[str] = None
action_str: Optional[str] = None
observation: Optional[str] = None
action: Optional[Action] = None
def is_final(self) -> bool:
"""
Check if the scratchpad unit is final.
"""
return self.action is None or (
'final' in self.action.action_name.lower() and
'answer' in self.action.action_name.lower()
)
class AgentEntity(BaseModel):
"""

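`is_final()` treats a scratchpad unit as terminal when no action was parsed, or when the action name contains both "final" and "answer", case-insensitively. A quick sketch of that predicate:
```
def is_final(action_name: str | None) -> bool:
    # Final when no action was parsed, or the name reads like "Final Answer".
    if action_name is None:
        return True
    lowered = action_name.lower()
    return 'final' in lowered and 'answer' in lowered

assert is_final(None)
assert is_final("Final Answer")
assert is_final("final_answer")
assert not is_final("search")
```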
View File

@@ -1,6 +1,7 @@
import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Union
from core.agent.base_agent_runner import BaseAgentRunner
@@ -10,21 +11,21 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Conversation, Message, MessageAgentThought
from models.model import Message
logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, conversation: Conversation,
message: Message,
query: str,
def run(self,
message: Message, query: str, **kwargs: Any
) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
@@ -35,40 +36,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
prompt_template = app_config.prompt_template.simple_prompt_template or ''
prompt_messages = self.history_prompt_messages
prompt_messages = self.organize_prompt_messages(
prompt_template=prompt_template,
query=query,
prompt_messages=prompt_messages
)
prompt_messages = self._init_system_message(prompt_template, prompt_messages)
prompt_messages = self._organize_user_query(query, prompt_messages)
# convert tools into ModelRuntime Tool format
prompt_messages_tools: list[PromptMessageTool] = []
tool_instances = {}
for tool in app_config.agent.tools if app_config.agent else []:
try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception:
# api tool may be deleted
continue
# save tool entity
tool_instances[tool.tool_name] = tool_entity
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# convert dataset tools into ModelRuntime Tool format
for dataset_tool in self.dataset_tools:
prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# save tool entity
tool_instances[dataset_tool.identity.name] = dataset_tool
tool_instances, prompt_messages_tools = self._init_prompt_tools()
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
# continue to run until there is not any tool call
function_call_state = True
agent_thoughts: list[MessageAgentThought] = []
llm_usage = {
'usage': None
}
@@ -207,19 +185,25 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
)
assistant_message = AssistantPromptMessage(
content='',
tool_calls=[]
)
if tool_calls:
prompt_messages.append(AssistantPromptMessage(
content='',
name='',
tool_calls=[AssistantPromptMessage.ToolCall(
assistant_message.tool_calls=[
AssistantPromptMessage.ToolCall(
id=tool_call[0],
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1],
arguments=json.dumps(tool_call[2], ensure_ascii=False)
)
) for tool_call in tool_calls]
))
) for tool_call in tool_calls
]
else:
assistant_message.content = response
prompt_messages.append(assistant_message)
# save thought
self.save_agent_thought(
@@ -239,12 +223,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
final_answer += response + '\n'
# update prompt messages
if response.strip():
prompt_messages.append(AssistantPromptMessage(
content=response,
))
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
@@ -287,9 +265,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
}
tool_responses.append(tool_response)
prompt_messages = self.organize_prompt_messages(
prompt_template=prompt_template,
query=None,
prompt_messages = self._organize_assistant_message(
tool_call_id=tool_call_id,
tool_call_name=tool_call_name,
tool_response=tool_response['tool_response'],
@@ -324,6 +300,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
iteration_step += 1
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -386,29 +364,68 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls
def organize_prompt_messages(self, prompt_template: str,
query: str = None,
tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
prompt_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize prompt messages
Initialize system message
"""
if not prompt_messages:
prompt_messages = [
if not prompt_messages and prompt_template:
return [
SystemPromptMessage(content=prompt_template),
UserPromptMessage(content=query),
]
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
prompt_message_contents = [TextPromptMessageContent(data=query)]
for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
if tool_response:
prompt_messages = prompt_messages.copy()
prompt_messages.append(
ToolPromptMessage(
content=tool_response,
tool_call_id=tool_call_id,
name=tool_call_name,
)
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize assistant message
"""
prompt_messages = deepcopy(prompt_messages)
if tool_response is not None:
prompt_messages.append(
ToolPromptMessage(
content=tool_response,
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As of now, GPT supports both function calling and vision at the first iteration.
We need to remove the image messages from the prompt messages after the first iteration.
"""
prompt_messages = deepcopy(prompt_messages)
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = '\n'.join([
content.data if content.type == PromptMessageContentType.TEXT else
'[image]' if content.type == PromptMessageContentType.IMAGE else
'[file]'
for content in prompt_message.content
])
return prompt_messages

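`_clear_user_prompt_image_messages` flattens multimodal user content to plain text. A simplified sketch over (type, data) tuples standing in for prompt message contents:
```
def flatten_content(parts: list[tuple[str, str]]) -> str:
    # Text survives; images and files become short placeholders.
    return '\n'.join(
        data if kind == 'text' else
        '[image]' if kind == 'image' else
        '[file]'
        for kind, data in parts
    )

print(flatten_content([('text', 'describe this'), ('image', '<bytes>')]))
# -> "describe this\n[image]"
```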
View File

@@ -0,0 +1,183 @@
import json
import re
from collections.abc import Generator
from typing import Union
from core.agent.entities import AgentScratchpadUnit
from core.model_runtime.entities.llm_entities import LLMResultChunk
class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None]) -> \
Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(json_str):
try:
action = json.loads(json_str)
action_name = None
action_input = None
for key, value in action.items():
if 'input' in key.lower():
action_input = value
else:
action_name = value
if action_name is not None and action_input is not None:
return AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
else:
return json_str or ''
except:
return json_str or ''
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
if not code_blocks:
return
for block in code_blocks:
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
yield parse_action(json_text)
code_block_cache = ''
code_block_delimiter_count = 0
in_code_block = False
json_cache = ''
json_quote_count = 0
in_json = False
got_json = False
action_cache = ''
action_str = 'action:'
action_idx = 0
thought_cache = ''
thought_str = 'thought:'
thought_idx = 0
for response in llm_response:
response = response.delta.message.content
if not isinstance(response, str):
continue
# stream
index = 0
while index < len(response):
steps = 1
delta = response[index:index+steps]
last_character = response[index-1] if index > 0 else ''
if delta == '`':
code_block_cache += delta
code_block_delimiter_count += 1
else:
if not in_code_block:
if code_block_delimiter_count > 0:
yield code_block_cache
code_block_cache = ''
else:
code_block_cache += delta
code_block_delimiter_count = 0
if not in_code_block and not in_json:
if delta.lower() == action_str[action_idx] and action_idx == 0:
if last_character not in ['\n', ' ', '']:
index += steps
yield delta
continue
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ''
action_idx = 0
index += steps
continue
elif delta.lower() == action_str[action_idx] and action_idx > 0:
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ''
action_idx = 0
index += steps
continue
else:
if action_cache:
yield action_cache
action_cache = ''
action_idx = 0
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
if last_character not in ['\n', ' ', '']:
index += steps
yield delta
continue
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ''
thought_idx = 0
index += steps
continue
elif delta.lower() == thought_str[thought_idx] and thought_idx > 0:
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ''
thought_idx = 0
index += steps
continue
else:
if thought_cache:
yield thought_cache
thought_cache = ''
thought_idx = 0
if code_block_delimiter_count == 3:
if in_code_block:
yield from extra_json_from_code_block(code_block_cache)
code_block_cache = ''
in_code_block = not in_code_block
code_block_delimiter_count = 0
if not in_code_block:
# handle single json
if delta == '{':
json_quote_count += 1
in_json = True
json_cache += delta
elif delta == '}':
json_cache += delta
if json_quote_count > 0:
json_quote_count -= 1
if json_quote_count == 0:
in_json = False
got_json = True
index += steps
continue
else:
if in_json:
json_cache += delta
if got_json:
got_json = False
yield parse_action(json_cache)
json_cache = ''
json_quote_count = 0
in_json = False
if not in_code_block and not in_json:
yield delta.replace('`', '')
index += steps
if code_block_cache:
yield code_block_cache
if json_cache:
yield parse_action(json_cache)

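`handle_react_stream_output` walks the stream one character at a time, counting backtick and brace delimiters. A self-contained sketch of just the brace-depth JSON capture; it deliberately ignores the code-fence and action:/thought: prefix handling the real parser also performs:
```
def capture_json(stream: str):
    cache, depth, in_json = '', 0, False
    for ch in stream:
        if ch == '{':
            depth += 1
            in_json = True
            cache += ch
        elif in_json:
            cache += ch
            if ch == '}':
                depth -= 1
                if depth == 0:
                    # A balanced top-level object has closed; emit it.
                    yield cache
                    cache, in_json = '', False

print(list(capture_json(
    'Thought: call a tool {"action": "search", "action_input": {"q": "dify"}} done')))
```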
View File

@@ -1,7 +1,8 @@
import logging
from typing import cast
from core.agent.cot_agent_runner import CotAgentRunner
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
@@ -11,8 +12,8 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, Mo
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
@@ -207,48 +208,40 @@ class AgentChatAppRunner(AppRunner):
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = CotAgentRunner(
tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity,
app_config=app_config,
model_config=application_generate_entity.model_config,
config=agent_entity,
queue_manager=queue_manager,
message=message,
user_id=application_generate_entity.user_id,
memory=memory,
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance
)
invoke_result = assistant_cot_runner.run(
conversation=conversation,
message=message,
query=query,
inputs=inputs,
)
# check LLM mode
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
runner_cls = CotChatAgentRunner
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value:
runner_cls = CotCompletionAgentRunner
else:
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
assistant_fc_runner = FunctionCallAgentRunner(
tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity,
app_config=app_config,
model_config=application_generate_entity.model_config,
config=agent_entity,
queue_manager=queue_manager,
message=message,
user_id=application_generate_entity.user_id,
memory=memory,
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance
)
invoke_result = assistant_fc_runner.run(
conversation=conversation,
message=message,
query=query,
)
runner_cls = FunctionCallAgentRunner
else:
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
runner = runner_cls(
tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity,
conversation=conversation,
app_config=app_config,
model_config=application_generate_entity.model_config,
config=agent_entity,
queue_manager=queue_manager,
message=message,
user_id=application_generate_entity.user_id,
memory=memory,
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance
)
invoke_result = runner.run(
message=message,
query=query,
inputs=inputs,
)
# handle invoke result
self._handle_invoke_result(

View File

@@ -156,6 +156,8 @@ class ChatAppRunner(AppRunner):
dataset_retrieval = DatasetRetrieval()
context = dataset_retrieval.retrieve(
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
model_config=application_generate_entity.model_config,
config=app_config.dataset,

View File

@@ -116,6 +116,8 @@ class CompletionAppRunner(AppRunner):
dataset_retrieval = DatasetRetrieval()
context = dataset_retrieval.retrieve(
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
model_config=application_generate_entity.model_config,
config=dataset_config,

View File

@@ -122,7 +122,7 @@ class MessageEndStreamResponse(StreamResponse):
"""
event: StreamEvent = StreamEvent.MESSAGE_END
id: str
metadata: Optional[dict] = None
metadata: dict = {}
class MessageFileStreamResponse(StreamResponse):

View File

@@ -1,12 +1,32 @@
import os
from typing import Any, Optional, Union
from typing import Any, Optional, TextIO, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from pydantic import BaseModel
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}
class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
file.flush() # ensure all printed content are written to file
class DifyAgentCallbackHandler(BaseModel):
"""Callback Handler that prints to std out."""
color: Optional[str] = ''
current_loop = 1

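The handler now vendors the two helpers it previously imported from langchain. A runnable demo restating them with a trimmed color map:
```
import sys
from typing import Optional, TextIO

_TEXT_COLOR_MAPPING = {"green": "32;1", "yellow": "33;1"}

def get_colored_text(text: str, color: str) -> str:
    return f"\u001b[{_TEXT_COLOR_MAPPING[color]}m\033[1;3m{text}\u001b[0m"

def print_text(text: str, color: Optional[str] = None, end: str = "",
               file: Optional[TextIO] = None) -> None:
    print(get_colored_text(text, color) if color else text, end=end, file=file)
    if file:
        file.flush()  # ensure all printed content is written out

print_text("on_tool_start: search", color="green", end="\n")
print_text("observation received", color="yellow", end="\n", file=sys.stderr)
```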
View File

@@ -41,7 +41,8 @@ class CacheEmbedding(Embeddings):
embedding_queue_embeddings = []
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
model_schema = model_type_instance.get_model_schema(self._model_instance.model,
self._model_instance.credentials)
max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
for i in range(0, len(embedding_queue_texts), max_chunks):
@@ -61,17 +62,20 @@ class CacheEmbedding(Embeddings):
except Exception as e:
logging.exception('Failed transform embedding: ', e)
cache_embeddings = []
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
text_embeddings[i] = embedding
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(model_name=self._model_instance.model,
hash=hash,
provider_name=self._model_instance.provider)
embedding_cache.set_embedding(embedding)
db.session.add(embedding_cache)
cache_embeddings.append(hash)
db.session.commit()
try:
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
text_embeddings[i] = embedding
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(model_name=self._model_instance.model,
hash=hash,
provider_name=self._model_instance.provider)
embedding_cache.set_embedding(embedding)
db.session.add(embedding_cache)
cache_embeddings.append(hash)
db.session.commit()
except IntegrityError:
db.session.rollback()
except Exception as ex:
db.session.rollback()
logger.error('Failed to embed documents: ', ex)

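The cache write is now guarded so that a concurrent insert of the same text hash rolls back instead of failing the request. The general SQLAlchemy shape, as a sketch with placeholder `db`, `Embedding`, and row values:
```
from sqlalchemy.exc import IntegrityError

def cache_embeddings(db, Embedding, rows):  # hypothetical session/model/rows
    try:
        for model_name, text_hash, vector in rows:
            db.session.add(Embedding(model_name=model_name, hash=text_hash,
                                     embedding=vector))
        db.session.commit()
    except IntegrityError:
        # Another worker cached the same hash first; the unique constraint
        # fired, so drop this transaction and keep serving.
        db.session.rollback()
```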
View File

@@ -29,16 +29,16 @@ class NodeJsTemplateTransformer(TemplateTransformer):
:param inputs: inputs
:return:
"""
# transform inputs to json string
inputs_str = json.dumps(inputs, indent=4)
inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False)
# replace code and inputs
runner = NODEJS_RUNNER.replace('{{code}}', code)
runner = runner.replace('{{inputs}}', inputs_str)
return runner, NODEJS_PRELOAD
@classmethod
def transform_response(cls, response: str) -> dict:
"""

View File

@@ -62,10 +62,10 @@ class Jinja2TemplateTransformer(TemplateTransformer):
# transform jinja2 template to python code
runner = PYTHON_RUNNER.replace('{{code}}', code)
runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4))
runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4, ensure_ascii=False))
return runner, JINJA2_PRELOAD
@classmethod
def transform_response(cls, response: str) -> dict:
"""
@@ -81,4 +81,4 @@ class Jinja2TemplateTransformer(TemplateTransformer):
return {
'result': result
}
}

View File

@@ -34,7 +34,7 @@ class PythonTemplateTransformer(TemplateTransformer):
"""
# transform inputs to json string
inputs_str = json.dumps(inputs, indent=4)
inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False)
# replace code and inputs
runner = PYTHON_RUNNER.replace('{{code}}', code)

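The only change across the three template transformers is `ensure_ascii=False`, which keeps non-ASCII input readable inside the generated runner script instead of \u-escaping it:
```
import json

inputs = {"question": "你好"}
print(json.dumps(inputs))                      # {"question": "\u4f60\u597d"}
print(json.dumps(inputs, ensure_ascii=False))  # {"question": "你好"}
```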
View File

@@ -19,6 +19,7 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -657,18 +658,25 @@ class IndexingRunner:
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for i in range(0, len(documents), chunk_size):
chunk_documents = documents[i:i + chunk_size]
futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
chunk_documents, dataset,
dataset_document, embedding_model_instance,
embedding_model_type_instance))
# create keyword index
create_keyword_thread = threading.Thread(target=self._process_keyword_index,
args=(current_app._get_current_object(),
dataset.id, dataset_document.id, documents))
create_keyword_thread.start()
if dataset.indexing_technique == 'high_quality':
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for i in range(0, len(documents), chunk_size):
chunk_documents = documents[i:i + chunk_size]
futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
chunk_documents, dataset,
dataset_document, embedding_model_instance,
embedding_model_type_instance))
for future in futures:
tokens += future.result()
for future in futures:
tokens += future.result()
create_keyword_thread.join()
indexing_end_at = time.perf_counter()
# update document status to completed
@@ -682,6 +690,27 @@ class IndexingRunner:
}
)
def _process_keyword_index(self, flask_app, dataset_id, document_id, documents):
with flask_app.app_context():
dataset = Dataset.query.filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
keyword = Keyword(dataset)
keyword.create(documents)
if dataset.indexing_technique != 'high_quality':
document_ids = [document.metadata['doc_id'] for document in documents]
db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document_id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing"
).update({
DocumentSegment.status: "completed",
DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.utcnow()
})
db.session.commit()
def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
embedding_model_instance, embedding_model_type_instance):
with flask_app.app_context():
@@ -700,7 +729,7 @@ class IndexingRunner:
)
# load index
index_processor.load(dataset, chunk_documents)
index_processor.load(dataset, chunk_documents, with_keywords=False)
document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter(

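Keyword indexing now runs on its own thread while the embedding pool (high-quality datasets only) does its work, and the runner joins the thread at the end. The concurrency shape, with stub work functions standing in for the real processing:
```
import threading
from concurrent.futures import ThreadPoolExecutor

def build_keyword_index(docs): ...          # stub for the keyword thread

def process_chunk(chunk) -> int:            # stub returning a token count
    return len(chunk)

def index(documents, chunk_size=10, high_quality=True) -> int:
    keyword_thread = threading.Thread(target=build_keyword_index,
                                      args=(documents,))
    keyword_thread.start()
    tokens = 0
    if high_quality:
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(process_chunk, documents[i:i + chunk_size])
                       for i in range(0, len(documents), chunk_size)]
            tokens = sum(f.result() for f in futures)
    keyword_thread.join()
    return tokens

print(index(list(range(25))))
```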
View File

@@ -3,6 +3,7 @@ from core.file.message_file_parser import MessageFileParser
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageRole,
TextPromptMessageContent,
@@ -124,7 +125,17 @@ class TokenBufferMemory:
else:
continue
message = f"{role}: {m.content}"
string_messages.append(message)
if isinstance(m.content, list):
inner_msg = ""
for content in m.content:
if isinstance(content, TextPromptMessageContent):
inner_msg += f"{content.data}\n"
elif isinstance(content, ImagePromptMessageContent):
inner_msg += "[image]\n"
string_messages.append(f"{role}: {inner_msg.strip()}")
else:
message = f"{role}: {m.content}"
string_messages.append(message)
return "\n".join(string_messages)

View File

@@ -99,6 +99,12 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-turbo-2024-04-09
value: gpt-4-turbo-2024-04-09
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-0125-preview
value: gpt-4-0125-preview

View File

@@ -74,7 +74,7 @@ provider_credential_schema:
label:
en_US: Available Model Name
zh_Hans: 可用模型名称
type: secret-input
type: text-input
placeholder:
en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation.
zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如amazon.titan-text-lite-v1)

View File

@@ -2,8 +2,6 @@ model: amazon.titan-text-express-v1
label:
en_US: Titan Text G1 - Express
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192

View File

@@ -2,8 +2,6 @@ model: amazon.titan-text-lite-v1
label:
en_US: Titan Text G1 - Lite
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096

View File

@@ -50,3 +50,4 @@ pricing:
output: '0.024'
unit: '0.001'
currency: USD
deprecated: true

View File

@@ -22,7 +22,7 @@ parameter_rules:
min: 0
max: 500
default: 0
- name: max_tokens_to_sample
- name: max_tokens
use_template: max_tokens
required: true
default: 4096

View File

@@ -8,9 +8,9 @@ model_properties:
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
- name: p
use_template: top_p
- name: top_k
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
@@ -19,7 +19,7 @@ parameter_rules:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
- name: max_tokens
use_template: max_tokens
required: true
default: 4096

View File

@@ -402,25 +402,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param credentials: model credentials
:return:
"""
if "anthropic.claude-3" in model:
try:
self._invoke_claude(model=model,
credentials=credentials,
prompt_messages=[{"role": "user", "content": "ping"}],
model_parameters={},
stop=None,
stream=False)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
required_params = {}
if "anthropic" in model:
required_params = {
"max_tokens": 32,
}
elif "ai21" in model:
# ValidationException: Malformed input request: #/temperature: expected type: Number, found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null, please reformat your input and try again.
required_params = {
"temperature": 0.7,
"topP": 0.9,
"maxTokens": 32,
}
try:
ping_message = UserPromptMessage(content="ping")
self._generate(model=model,
self._invoke(model=model,
credentials=credentials,
prompt_messages=[ping_message],
model_parameters={},
model_parameters=required_params,
stream=False)
except ClientError as ex:
@@ -503,7 +503,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
if model_prefix == "amazon":
payload["textGenerationConfig"] = { **model_parameters }
payload["textGenerationConfig"]["stopSequences"] = ["User:"] + (stop if stop else [])
payload["textGenerationConfig"]["stopSequences"] = ["User:"]
payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
@@ -513,10 +513,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
payload["maxTokens"] = model_parameters.get("maxTokens")
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
# jurassic models only support a single stop sequence
if stop:
payload["stopSequences"] = stop[0]
if model_parameters.get("presencePenalty"):
payload["presencePenalty"] = {model_parameters.get("presencePenalty")}
if model_parameters.get("frequencyPenalty"):

View File

@@ -1,3 +1,5 @@
- command-r
- command-r-plus
- command-chat
- command-light-chat
- command-nightly-chat

View File

@@ -31,7 +31,7 @@ parameter_rules:
max: 500
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
- name: preamble_override
label:

View File

@@ -31,7 +31,7 @@ parameter_rules:
max: 500
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
- name: preamble_override
label:

View File

@@ -31,7 +31,7 @@ parameter_rules:
max: 500
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
- name: preamble_override
label:

View File

@@ -35,7 +35,7 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
pricing:
input: '0.3'

View File

@@ -35,7 +35,7 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
pricing:
input: '0.3'

View File

@@ -31,7 +31,7 @@ parameter_rules:
max: 500
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
- name: preamble_override
label:

View File

@@ -35,7 +35,7 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
pricing:
input: '1.0'

View File

@@ -0,0 +1,45 @@
model: command-r-plus
label:
en_US: command-r-plus
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 1024
max: 4096
pricing:
input: '3'
output: '15'
unit: '0.000001'
currency: USD

View File

@@ -0,0 +1,45 @@
model: command-r
label:
en_US: command-r
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 1024
max: 4096
pricing:
input: '0.5'
output: '1.5'
unit: '0.000001'
currency: USD

View File

@@ -35,7 +35,7 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
pricing:
input: '1.0'

View File

@@ -1,20 +1,38 @@
import json
import logging
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import Optional, Union, cast
import cohere
from cohere.responses import Chat, Generations
from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
from cohere.responses.generation import StreamingGenerations, StreamingText
from cohere import (
ChatMessage,
ChatStreamRequestToolResultsItem,
GenerateStreamedResponse,
GenerateStreamedResponse_StreamEnd,
GenerateStreamedResponse_StreamError,
GenerateStreamedResponse_TextGeneration,
Generation,
NonStreamedChatResponse,
StreamedChatResponse,
StreamedChatResponse_StreamEnd,
StreamedChatResponse_TextGeneration,
StreamedChatResponse_ToolCallsGeneration,
Tool,
ToolCall,
ToolParameterDefinitionsValue,
)
from cohere.core import RequestOptions
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContentType,
PromptMessageRole,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
@@ -64,6 +82,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
@@ -159,19 +178,26 @@ class CohereLargeLanguageModel(LargeLanguageModel):
if stop:
model_parameters['end_sequences'] = stop
response = client.generate(
prompt=prompt_messages[0].content,
model=model,
stream=stream,
**model_parameters,
)
if stream:
response = client.generate_stream(
prompt=prompt_messages[0].content,
model=model,
**model_parameters,
request_options=RequestOptions(max_retries=0)
)
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
else:
response = client.generate(
prompt=prompt_messages[0].content,
model=model,
**model_parameters,
request_options=RequestOptions(max_retries=0)
)
return self._handle_generate_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
def _handle_generate_response(self, model: str, credentials: dict, response: Generation,
prompt_messages: list[PromptMessage]) \
-> LLMResult:
"""
@@ -191,8 +217,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
)
# calculate num tokens
prompt_tokens = response.meta['billed_units']['input_tokens']
completion_tokens = response.meta['billed_units']['output_tokens']
prompt_tokens = int(response.meta.billed_units.input_tokens)
completion_tokens = int(response.meta.billed_units.output_tokens)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -207,7 +233,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
return response
def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
@@ -220,8 +246,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
index = 1
full_assistant_content = ''
for chunk in response:
if isinstance(chunk, StreamingText):
chunk = cast(StreamingText, chunk)
if isinstance(chunk, GenerateStreamedResponse_TextGeneration):
chunk = cast(GenerateStreamedResponse_TextGeneration, chunk)
text = chunk.text
if text is None:
@@ -244,10 +270,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
)
index += 1
elif chunk is None:
elif isinstance(chunk, GenerateStreamedResponse_StreamEnd):
chunk = cast(GenerateStreamedResponse_StreamEnd, chunk)
# calculate num tokens
prompt_tokens = response.meta['billed_units']['input_tokens']
completion_tokens = response.meta['billed_units']['output_tokens']
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
completion_tokens = self._num_tokens_from_messages(
model,
credentials,
[AssistantPromptMessage(content=full_assistant_content)]
)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -258,14 +290,18 @@ class CohereLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=''),
finish_reason=response.finish_reason,
finish_reason=chunk.finish_reason,
usage=usage
)
)
break
elif isinstance(chunk, GenerateStreamedResponse_StreamError):
chunk = cast(GenerateStreamedResponse_StreamError, chunk)
raise InvokeBadRequestError(chunk.err)
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke llm chat model
@@ -274,6 +310,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:param credentials: credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
@@ -282,32 +319,49 @@ class CohereLargeLanguageModel(LargeLanguageModel):
# initialize client
client = cohere.Client(credentials.get('api_key'))
if user:
model_parameters['user_name'] = user
if stop:
model_parameters['stop_sequences'] = stop
message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
if tools:
if len(tools) == 1:
raise ValueError("Cohere tool call requires at least two tools to be specified.")
model_parameters['tools'] = self._convert_tools(tools)
message, chat_histories, tool_results \
= self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
if tool_results:
model_parameters['tool_results'] = tool_results
# chat model
real_model = model
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
real_model = model.removesuffix('-chat')
response = client.chat(
message=message,
chat_history=chat_histories,
model=real_model,
stream=stream,
return_preamble=True,
**model_parameters,
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
response = client.chat_stream(
message=message,
chat_history=chat_histories,
model=real_model,
**model_parameters,
request_options=RequestOptions(max_retries=0)
)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
else:
response = client.chat(
message=message,
chat_history=chat_histories,
model=real_model,
**model_parameters,
request_options=RequestOptions(max_retries=0)
)
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse,
prompt_messages: list[PromptMessage]) \
-> LLMResult:
"""
Handle llm chat response
@@ -316,14 +370,27 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:param stop: stop words
:return: llm response
"""
assistant_text = response.text
tool_calls = []
if response.tool_calls:
for cohere_tool_call in response.tool_calls:
tool_call = AssistantPromptMessage.ToolCall(
id=cohere_tool_call.name,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=cohere_tool_call.name,
arguments=json.dumps(cohere_tool_call.parameters)
)
)
tool_calls.append(tool_call)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
content=assistant_text,
tool_calls=tool_calls
)
# calculate num tokens
@@ -333,44 +400,38 @@ class CohereLargeLanguageModel(LargeLanguageModel):
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
if stop:
# enforce stop tokens
assistant_text = self.enforce_stop_tokens(assistant_text, stop)
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
# transform response
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
system_fingerprint=response.preamble
usage=usage
)
return response
def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None) -> Generator:
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
response: Iterator[StreamedChatResponse],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm chat stream response
:param model: model name
:param response: response
:param prompt_messages: prompt messages
:param stop: stop words
:return: llm response chunk generator
"""
def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
preamble: Optional[str] = None) -> LLMResultChunk:
def final_response(full_text: str,
tool_calls: list[AssistantPromptMessage.ToolCall],
index: int,
finish_reason: Optional[str] = None) -> LLMResultChunk:
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
full_assistant_prompt_message = AssistantPromptMessage(
content=full_text
content=full_text,
tool_calls=tool_calls
)
completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
@@ -380,10 +441,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
return LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=preamble,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=''),
message=AssistantPromptMessage(content='', tool_calls=tool_calls),
finish_reason=finish_reason,
usage=usage
)
@@ -391,9 +451,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
index = 1
full_assistant_content = ''
tool_calls = []
for chunk in response:
if isinstance(chunk, StreamTextGeneration):
chunk = cast(StreamTextGeneration, chunk)
if isinstance(chunk, StreamedChatResponse_TextGeneration):
chunk = cast(StreamedChatResponse_TextGeneration, chunk)
text = chunk.text
if text is None:
@@ -404,12 +465,6 @@ class CohereLargeLanguageModel(LargeLanguageModel):
content=text
)
# stop
# notice: This logic can only cover few stop scenarios
if stop and text in stop:
yield final_response(full_assistant_content, index, 'stop')
break
full_assistant_content += text
yield LLMResultChunk(
@@ -422,39 +477,96 @@ class CohereLargeLanguageModel(LargeLanguageModel):
)
index += 1
elif isinstance(chunk, StreamEnd):
chunk = cast(StreamEnd, chunk)
yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
elif isinstance(chunk, StreamedChatResponse_ToolCallsGeneration):
chunk = cast(StreamedChatResponse_ToolCallsGeneration, chunk)
if chunk.tool_calls:
for cohere_tool_call in chunk.tool_calls:
tool_call = AssistantPromptMessage.ToolCall(
id=cohere_tool_call.name,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=cohere_tool_call.name,
arguments=json.dumps(cohere_tool_call.parameters)
)
)
tool_calls.append(tool_call)
elif isinstance(chunk, StreamedChatResponse_StreamEnd):
chunk = cast(StreamedChatResponse_StreamEnd, chunk)
yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason)
index += 1
def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
-> tuple[str, list[dict]]:
-> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]:
"""
Convert prompt messages to message and chat histories
:param prompt_messages: prompt messages
:return:
"""
chat_histories = []
latest_tool_call_n_outputs = []
for prompt_message in prompt_messages:
chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))
if prompt_message.role == PromptMessageRole.ASSISTANT:
prompt_message = cast(AssistantPromptMessage, prompt_message)
if prompt_message.tool_calls:
for tool_call in prompt_message.tool_calls:
latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem(
call=ToolCall(
name=tool_call.function.name,
parameters=json.loads(tool_call.function.arguments)
),
outputs=[]
))
else:
cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
if cohere_prompt_message:
chat_histories.append(cohere_prompt_message)
elif prompt_message.role == PromptMessageRole.TOOL:
prompt_message = cast(ToolPromptMessage, prompt_message)
if latest_tool_call_n_outputs:
i = 0
for tool_call_n_outputs in latest_tool_call_n_outputs:
if tool_call_n_outputs.call.name == prompt_message.tool_call_id:
latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem(
call=ToolCall(
name=tool_call_n_outputs.call.name,
parameters=tool_call_n_outputs.call.parameters
),
outputs=[{
"result": prompt_message.content
}]
)
break
i += 1
else:
cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
if cohere_prompt_message:
chat_histories.append(cohere_prompt_message)
if latest_tool_call_n_outputs:
new_latest_tool_call_n_outputs = []
for tool_call_n_outputs in latest_tool_call_n_outputs:
if tool_call_n_outputs.outputs:
new_latest_tool_call_n_outputs.append(tool_call_n_outputs)
latest_tool_call_n_outputs = new_latest_tool_call_n_outputs
# get latest message from chat histories and pop it
if len(chat_histories) > 0:
latest_message = chat_histories.pop()
message = latest_message['message']
message = latest_message.message
else:
raise ValueError('Prompt messages is empty')
return message, chat_histories
return message, chat_histories, latest_tool_call_n_outputs
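For clarity, here is a hypothetical walk-through of what this converter is expected to return for one tool-call round trip. Plain dicts stand in for the ChatMessage / ChatStreamRequestToolResultsItem objects and all values are illustrative:

# Given, in order: a user question, an assistant turn that only carries a
# tool call, the matching tool result, and a follow-up user question.
message = "Thanks, and in Berlin?"            # last message, popped from history
chat_histories = [                            # ChatMessage-shaped entries
    {"role": "USER", "message": "What's the weather in Paris?"},
    # the assistant tool-call turn has no content, so it yields no ChatMessage;
    # it is tracked in tool_results instead
]
tool_results = [                              # ChatStreamRequestToolResultsItem-shaped
    {"call": {"name": "get_weather", "parameters": {"city": "Paris"}},
     "outputs": [{"result": "18°C, cloudy"}]},
]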
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> Optional[ChatMessage]:
"""
Convert PromptMessage to dict for Cohere model
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "USER", "message": message.content}
chat_message = ChatMessage(role="USER", message=message.content)
else:
sub_message_text = ''
for message_content in message.content:
@@ -462,20 +574,57 @@ class CohereLargeLanguageModel(LargeLanguageModel):
message_content = cast(TextPromptMessageContent, message_content)
sub_message_text += message_content.data
message_dict = {"role": "USER", "message": sub_message_text}
chat_message = ChatMessage(role="USER", message=sub_message_text)
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "CHATBOT", "message": message.content}
if not message.content:
return None
chat_message = ChatMessage(role="CHATBOT", message=message.content)
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "USER", "message": message.content}
chat_message = ChatMessage(role="USER", message=message.content)
elif isinstance(message, ToolPromptMessage):
return None
else:
raise ValueError(f"Got unknown type {message}")
if message.name:
message_dict["user_name"] = message.name
return chat_message
return message_dict
def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]:
"""
Convert tools to Cohere tool definitions
"""
cohere_tools = []
for tool in tools:
properties = tool.parameters['properties']
required_properties = tool.parameters['required']
parameter_definitions = {}
for p_key, p_val in properties.items():
required = False
if p_key in required_properties:
required = True
desc = p_val['description']
if 'enum' in p_val:
desc += (f"; Only accepts one of the following predefined options: "
f"[{', '.join(p_val['enum'])}]")
parameter_definitions[p_key] = ToolParameterDefinitionsValue(
description=desc,
type=p_val['type'],
required=required
)
cohere_tool = Tool(
name=tool.name,
description=tool.description,
parameter_definitions=parameter_definitions
)
cohere_tools.append(cohere_tool)
return cohere_tools
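As a rough illustration of the conversion above (the tool schema and descriptions are made up; ToolParameterDefinitionsValue and Tool are the cohere SDK classes this file already references):

# Hypothetical OpenAI-style parameter schema on a PromptMessageTool:
parameters = {
    "type": "object",
    "properties": {
        "city": {"type": "str", "description": "City to query"},
        "unit": {"type": "str", "description": "Temperature unit",
                 "enum": ["celsius", "fahrenheit"]},
    },
    "required": ["city"],
}
# _convert_tools would emit roughly:
#   Tool(name="get_weather", description="Look up current weather",
#        parameter_definitions={
#            "city": ToolParameterDefinitionsValue(
#                description="City to query", type="str", required=True),
#            "unit": ToolParameterDefinitionsValue(
#                description="Temperature unit; Only accepts one of the following "
#                            "predefined options: [celsius, fahrenheit]",
#                type="str", required=False)})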
def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
"""
@@ -494,12 +643,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
model=model
)
return response.length
return len(response.tokens)
def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int:
"""Calculate num tokens Cohere model."""
messages = [self._convert_prompt_message_to_dict(m) for m in messages]
message_strs = [f"{message['role']}: {message['message']}" for message in messages]
calc_messages = []
for message in messages:
cohere_message = self._convert_prompt_message_to_dict(message)
if cohere_message:
calc_messages.append(cohere_message)
message_strs = [f"{message.role}: {message.message}" for message in calc_messages]
message_str = "\n".join(message_strs)
real_model = model
@@ -565,13 +718,21 @@ class CohereLargeLanguageModel(LargeLanguageModel):
"""
return {
InvokeConnectionError: [
cohere.CohereConnectionError
cohere.errors.service_unavailable_error.ServiceUnavailableError
],
InvokeServerUnavailableError: [
cohere.errors.internal_server_error.InternalServerError
],
InvokeRateLimitError: [
cohere.errors.too_many_requests_error.TooManyRequestsError
],
InvokeAuthorizationError: [
cohere.errors.unauthorized_error.UnauthorizedError,
cohere.errors.forbidden_error.ForbiddenError
],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: [
cohere.CohereAPIError,
cohere.CohereError,
cohere.core.api_error.ApiError,
cohere.errors.bad_request_error.BadRequestError,
cohere.errors.not_found_error.NotFoundError,
]
}

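A minimal sketch of how the remapped v5 SDK exceptions surface at call time, assuming cohere>=5 and the exception paths used in the mapping above:

import cohere

client = cohere.Client("invalid-key")  # placeholder credentials
try:
    client.chat(message="ping", model="command-r")
except cohere.errors.unauthorized_error.UnauthorizedError:
    pass  # mapped to InvokeAuthorizationError above
except cohere.core.api_error.ApiError:
    pass  # generic fallback, mapped to InvokeBadRequestError above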

@@ -1,6 +1,7 @@
from typing import Optional
import cohere
from cohere.core import RequestOptions
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
@@ -44,19 +45,21 @@ class CohereRerankModel(RerankModel):
# initialize client
client = cohere.Client(credentials.get('api_key'))
results = client.rerank(
response = client.rerank(
query=query,
documents=docs,
model=model,
top_n=top_n
top_n=top_n,
return_documents=True,
request_options=RequestOptions(max_retries=0)
)
rerank_documents = []
for idx, result in enumerate(results):
for idx, result in enumerate(response.results):
# format document
rerank_document = RerankDocument(
index=result.index,
text=result.document['text'],
text=result.document.text,
score=result.relevance_score,
)
@@ -108,13 +111,21 @@ class CohereRerankModel(RerankModel):
"""
return {
InvokeConnectionError: [
cohere.CohereConnectionError,
cohere.errors.service_unavailable_error.ServiceUnavailableError
],
InvokeServerUnavailableError: [
cohere.errors.internal_server_error.InternalServerError
],
InvokeRateLimitError: [
cohere.errors.too_many_requests_error.TooManyRequestsError
],
InvokeAuthorizationError: [
cohere.errors.unauthorized_error.UnauthorizedError,
cohere.errors.forbidden_error.ForbiddenError
],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: [
cohere.CohereAPIError,
cohere.CohereError,
cohere.core.api_error.ApiError,
cohere.errors.bad_request_error.BadRequestError,
cohere.errors.not_found_error.NotFoundError,
]
}

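For reference, a standalone sketch of the v5-style rerank call shape used above (the API key, model name and documents are placeholders):

import cohere
from cohere.core import RequestOptions

client = cohere.Client("your-api-key")  # placeholder key
response = client.rerank(
    query="What is the capital of France?",
    documents=["Berlin is in Germany.", "Paris is the capital of France."],
    model="rerank-english-v2.0",
    top_n=1,
    return_documents=True,  # without this, result.document would be None in v5
    request_options=RequestOptions(max_retries=0),
)
for result in response.results:
    print(result.index, result.relevance_score, result.document.text)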

@@ -3,7 +3,7 @@ from typing import Optional
import cohere
import numpy as np
from cohere.responses import Tokens
from cohere.core import RequestOptions
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@@ -52,8 +52,8 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
text=text
)
for j in range(0, tokenize_response.length, context_size):
tokens += [tokenize_response.token_strings[j: j + context_size]]
for j in range(0, len(tokenize_response), context_size):
tokens += [tokenize_response[j: j + context_size]]
indices += [i]
batched_embeddings = []
@@ -127,9 +127,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
except Exception as e:
raise self._transform_invoke_error(e)
return response.length
return len(response)
def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens:
def _tokenize(self, model: str, credentials: dict, text: str) -> list[str]:
"""
Tokenize text
:param model: model name
@@ -138,17 +138,19 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
if not text:
return Tokens([], [], {})
return []
# initialize client
client = cohere.Client(credentials.get('api_key'))
response = client.tokenize(
text=text,
model=model
model=model,
offline=False,
request_options=RequestOptions(max_retries=0)
)
return response
return response.token_strings
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@@ -184,10 +186,11 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
response = client.embed(
texts=texts,
model=model,
input_type='search_document' if len(texts) > 1 else 'search_query'
input_type='search_document' if len(texts) > 1 else 'search_query',
request_options=RequestOptions(max_retries=1)
)
return response.embeddings, response.meta['billed_units']['input_tokens']
return response.embeddings, int(response.meta.billed_units.input_tokens)
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
@@ -231,13 +234,21 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
"""
return {
InvokeConnectionError: [
cohere.CohereConnectionError
cohere.errors.service_unavailable_error.ServiceUnavailableError
],
InvokeServerUnavailableError: [
cohere.errors.internal_server_error.InternalServerError
],
InvokeRateLimitError: [
cohere.errors.too_many_requests_error.TooManyRequestsError
],
InvokeAuthorizationError: [
cohere.errors.unauthorized_error.UnauthorizedError,
cohere.errors.forbidden_error.ForbiddenError
],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: [
cohere.CohereAPIError,
cohere.CohereError,
cohere.core.api_error.ApiError,
cohere.errors.bad_request_error.BadRequestError,
cohere.errors.not_found_error.NotFoundError,
]
}

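A small sketch of the v5 tokenize response and the context-window batching performed above (the key, model name and sizes are illustrative):

import cohere
from cohere.core import RequestOptions

client = cohere.Client("your-api-key")  # placeholder key
resp = client.tokenize(text="hello world " * 300, model="embed-english-v3.0",
                       offline=False, request_options=RequestOptions(max_retries=0))
token_strings = resp.token_strings      # list[str] in the v5 SDK, per the diff above
context_size = 512                      # illustrative window size
batches = [token_strings[j:j + context_size]
           for j in range(0, len(token_strings), context_size)]
print(len(token_strings), len(batches))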

@@ -0,0 +1,37 @@
model: gemini-1.5-pro-latest
label:
en_US: Gemini 1.5 Pro
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD


@@ -1,8 +1,31 @@
import json
from collections.abc import Generator
from typing import Optional, Union
from typing import Optional, Union, cast
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
import requests
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
@@ -13,6 +36,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
self._add_function_call(model, credentials)
user = user[:32] if user else None
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
@@ -20,7 +44,286 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
return AIModelEntity(
model=model,
label=I18nObject(en_US=model, zh_Hans=model),
model_type=ModelType.LLM,
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
if credentials.get('function_calling_type') == 'tool_call'
else [],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
ModelPropertyKey.MODE: LLMMode.CHAT.value,
},
parameter_rules=[
ParameterRule(
name='temperature',
use_template='temperature',
label=I18nObject(en_US='Temperature', zh_Hans='温度'),
type=ParameterType.FLOAT,
),
ParameterRule(
name='max_tokens',
use_template='max_tokens',
default=512,
min=1,
max=int(credentials.get('max_tokens', 4096)),
label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
type=ParameterType.INT,
),
ParameterRule(
name='top_p',
use_template='top_p',
label=I18nObject(en_US='Top P', zh_Hans='Top P'),
type=ParameterType.FLOAT,
),
]
)
def _add_custom_parameters(self, credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
def _add_function_call(self, model: str, credentials: dict) -> None:
model_schema = self.get_model_schema(model, credentials)
if model_schema and set([
ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
]).intersection(model_schema.features or []):
credentials['function_calling_type'] = 'tool_call'
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for OpenAI API format
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {
"url": message_content.data,
"detail": message_content.detail.value
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = []
for function_call in message.tool_calls:
message_dict["tool_calls"].append({
"id": function_call.id,
"type": function_call.type,
"function": {
"name": function_call.function.name,
"arguments": function_call.function.arguments
}
})
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
if message.name:
message_dict["name"] = message.name
return message_dict
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
:param response_tool_calls: response tool calls
:return: list of tool calls
"""
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call["id"] if response_tool_call.get("id") else "",
type=response_tool_call["type"] if response_tool_call.get("type") else "",
function=function
)
tool_calls.append(tool_call)
return tool_calls
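To illustrate what this parser consumes: streamed OpenAI-style tool calls arrive fragmented across deltas, roughly like the hypothetical payloads below; missing fields default to empty strings here and the fragments are stitched back together by increase_tool_call further down.

# Hypothetical delta payloads for one tool call, split across three chunks:
delta_1 = [{"id": "call_0", "type": "function",
            "function": {"name": "get_weather", "arguments": ""}}]
delta_2 = [{"function": {"arguments": '{"city": "Par'}}]
delta_3 = [{"function": {"arguments": 'is"}'}}]
# each list becomes AssistantPromptMessage.ToolCall objects; concatenating the
# 'arguments' fragments yields the complete '{"city": "Paris"}'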
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: model credentials
:param response: streamed response
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ''
chunk_index = 0
def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
-> LLMResultChunk:
# calculate num tokens
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
return LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=message,
finish_reason=finish_reason,
usage=usage
)
)
tools_calls: list[AssistantPromptMessage.ToolCall] = []
finish_reason = "Unknown"
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_name: str):
if not tool_name:
return tools_calls[-1]
tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id='',
type='',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
if chunk:
# ignore sse comments
if chunk.startswith(':'):
continue
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered."
)
break
if not chunk_json or len(chunk_json['choices']) == 0:
continue
choice = chunk_json['choices'][0]
finish_reason = chunk_json['choices'][0].get('finish_reason')
chunk_index += 1
if 'delta' in choice:
delta = choice['delta']
delta_content = delta.get('content')
assistant_message_tool_calls = delta.get('tool_calls', None)
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
increase_tool_call(tool_calls)
if delta_content is None or delta_content == '':
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta_content,
tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta_content
elif 'text' in choice:
choice_text = choice.get('text', '')
if choice_text == '':
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
full_assistant_content += choice_text
else:
continue
# check payload indicator for completion
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
)
)
chunk_index += 1
if tools_calls:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=AssistantPromptMessage(
tool_calls=tools_calls,
content=""
),
)
)
yield create_final_llm_result_chunk(
index=chunk_index,
message=AssistantPromptMessage(content=""),
finish_reason=finish_reason
)

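For orientation, a hypothetical SSE exchange as this handler would see it via response.iter_lines(decode_unicode=True, delimiter="\n\n"):

# data: {"choices": [{"delta": {"content": "It is "}}]}
#
# data: {"choices": [{"delta": {"content": "sunny."}, "finish_reason": null}]}
#
# data: {"choices": [{"delta": {}, "finish_reason": "stop"}]}
#
# data: [DONE]
#
# The trailing "[DONE]" sentinel is not JSON, so json.loads raises and the
# except branch above yields a final chunk and breaks out of the loop.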

@@ -20,6 +20,7 @@ supported_model_types:
- llm
configurate_methods:
- predefined-model
- customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
@@ -30,3 +31,51 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: '4096'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
default: '4096'
type: text-input
- variable: function_calling_type
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: no_call
label:
en_US: Not supported
zh_Hans: 不支持
- value: tool_call
label:
en_US: Tool Call
zh_Hans: Tool Call

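Illustratively, filling in the form above might yield a credentials dict like the following, which get_customizable_model_schema then reads (all values are placeholders):

credentials = {
    'api_key': 'sk-...',               # placeholder
    'context_size': '16384',           # text-input, parsed with int(...)
    'max_tokens': '4096',
    'function_calling_type': 'tool_call',
}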

@@ -1,4 +1,6 @@
- gpt-4
- gpt-4-turbo
- gpt-4-turbo-2024-04-09
- gpt-4-turbo-preview
- gpt-4-32k
- gpt-4-1106-preview


@@ -0,0 +1,57 @@
model: gpt-4-turbo-2024-04-09
label:
zh_Hans: gpt-4-turbo-2024-04-09
en_US: gpt-4-turbo-2024-04-09
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 4096
- name: seed
label:
zh_Hans: 种子
en_US: Seed
type: int
help:
zh_Hans: 如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint
响应参数来监视变化。
en_US: If specified, model will make a best effort to sample deterministically,
such that repeated requests with the same seed and parameters should return
the same result. Determinism is not guaranteed, and you should refer to the
system_fingerprint response parameter to monitor changes in the backend.
required: false
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.01'
output: '0.03'
unit: '0.001'
currency: USD


@@ -0,0 +1,57 @@
model: gpt-4-turbo
label:
zh_Hans: gpt-4-turbo
en_US: gpt-4-turbo
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 4096
- name: seed
label:
zh_Hans: 种子
en_US: Seed
type: int
help:
zh_Hans: 如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint
响应参数来监视变化。
en_US: If specified, model will make a best effort to sample deterministically,
such that repeated requests with the same seed and parameters should return
the same result. Determinism is not guaranteed, and you should refer to the
system_fingerprint response parameter to monitor changes in the backend.
required: false
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.01'
output: '0.03'
unit: '0.001'
currency: USD


@@ -547,6 +547,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
if user:
extra_model_kwargs['user'] = user
# clear illegal prompt messages
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
# chat model
response = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
@@ -757,6 +760,31 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
return tool_call
def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Clear illegal prompt messages for OpenAI API
:param model: model name
:param prompt_messages: prompt messages
:return: cleaned prompt messages
"""
checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09']
if model in checklist:
# count how many user messages are there
user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
if user_message_count > 1:
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = '\n'.join([
item.data if item.type == PromptMessageContentType.TEXT else
'[IMAGE]' if item.type == PromptMessageContentType.IMAGE else ''
for item in prompt_message.content
])
return prompt_messages
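A worked example of the flattening, with hypothetical content (the list items stand in for TextPromptMessageContent / ImagePromptMessageContent):

# Before: a second UserPromptMessage with multimodal content on gpt-4-turbo
content_before = [
    {"type": "text", "data": "Describe this image:"},
    {"type": "image", "data": "https://example.com/cat.png"},
]
# After _clear_illegal_prompt_messages: parts joined with '\n', images become "[IMAGE]"
content_after = "Describe this image:\n[IMAGE]"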
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for OpenAI API


@@ -167,23 +167,27 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
"""
generate custom model entities from credentials
"""
support_function_call = False
features = []
function_calling_type = credentials.get('function_calling_type', 'no_call')
if function_calling_type == 'function_call':
features = [ModelFeature.TOOL_CALL]
support_function_call = True
features.append(ModelFeature.TOOL_CALL)
endpoint_url = credentials["endpoint_url"]
# if not endpoint_url.endswith('/'):
# endpoint_url += '/'
# if 'https://api.openai.com/v1/' == endpoint_url:
# features = [ModelFeature.STREAM_TOOL_CALL]
# features.append(ModelFeature.STREAM_TOOL_CALL)
vision_support = credentials.get('vision_support', 'not_support')
if vision_support == 'support':
features.append(ModelFeature.VISION)
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
features=features if support_function_call else [],
features=features,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
ModelPropertyKey.MODE: credentials.get('mode'),
@@ -378,13 +382,41 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
delimiter = credentials.get("stream_mode_delimiter", "\n\n")
delimiter = codecs.decode(delimiter, "unicode_escape")
tools_calls: list[AssistantPromptMessage.ToolCall] = []
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_call_id: str):
tool_call = next(
(tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None
)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id='',
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name='',
arguments=''
)
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.id)
# update tool call
tool_call.id = new_tool_call.id
tool_call.type = new_tool_call.type
tool_call.function.name = new_tool_call.function.name
tool_call.function.arguments += new_tool_call.function.arguments
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
if chunk:
# ignore sse comments
if chunk.startswith(':'):
continue
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
# stream ended
@@ -405,8 +437,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if 'delta' in choice:
delta = choice['delta']
delta_content = delta.get('content')
if delta_content is None or delta_content == '':
continue
assistant_message_tool_calls = delta.get('tool_calls', None)
# assistant_message_function_call = delta.delta.function_call
@@ -414,6 +444,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
increase_tool_call(tool_calls)
if delta_content is None or delta_content == '':
continue
# function_call = self._extract_response_function_call(assistant_message_function_call)
# tool_calls = [function_call] if function_call else []
@@ -437,6 +472,18 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
# check payload indicator for completion
if finish_reason is not None:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=AssistantPromptMessage(
tool_calls=tools_calls,
),
finish_reason=finish_reason
)
)
yield create_final_llm_result_chunk(
index=chunk_index,
message=assistant_prompt_message,
@@ -573,7 +620,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return message_dict
def _num_tokens_from_string(self, model: str, text: str,
def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessageContent]],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Approximate num tokens for model with gpt2 tokenizer.
@@ -583,7 +630,16 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
:param tools: tools for tool calling
:return: number of tokens
"""
num_tokens = self._get_num_tokens_by_gpt2(text)
if isinstance(text, str):
full_text = text
else:
full_text = ''
for message_content in text:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
full_text += message_content.data
num_tokens = self._get_num_tokens_by_gpt2(full_text)
if tools:
num_tokens += self._num_tokens_for_tools(tools)
@@ -735,4 +791,4 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
function=function
)
return tool_call
return tool_call

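The widened _num_tokens_from_string above reduces a multimodal content list to its text parts before GPT-2 counting; a minimal standalone sketch of that reduction (the import path is the one this module already uses):

from core.model_runtime.entities.message_entities import PromptMessageContentType

def flatten_content(text):
    """Return str input unchanged; join the TEXT parts of a content list."""
    if isinstance(text, str):
        return text
    return ''.join(part.data for part in text
                   if part.type == PromptMessageContentType.TEXT)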

@@ -97,6 +97,25 @@ model_credential_schema:
label:
en_US: Not Support
zh_Hans: 不支持
- variable: vision_support
show_on:
- variable: __model_type
value: llm
label:
zh_Hans: Vision 支持
en_US: Vision Support
type: select
required: false
default: no_support
options:
- value: support
label:
en_US: Support
zh_Hans: 支持
- value: no_support
label:
en_US: Not Support
zh_Hans: 不支持
- variable: stream_mode_delimiter
label:
zh_Hans: 流模式返回结果的分隔符


@@ -1,6 +1,6 @@
model: ernie-3.5-8k
model: ernie-3.5-4k-0205
label:
en_US: Ernie-3.5-8K
en_US: Ernie-3.5-4k-0205
model_type: llm
features:
- agent-thought


@@ -232,8 +232,8 @@ class SimplePromptTransform(PromptTransform):
)
),
max_token_limit=rest_tokens,
ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
)
# get prompt


@@ -24,56 +24,67 @@ class Jieba(BaseKeyword):
self._config = KeywordTableConfig()
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
self._save_dataset_keyword_table(keyword_table)
self._save_dataset_keyword_table(keyword_table)
return self
return self
def add_texts(self, texts: list[Document], **kwargs):
keyword_table_handler = JiebaKeywordTableHandler()
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get('keywords_list', None)
for i in range(len(texts)):
text = texts[i]
if keywords_list:
keywords = keywords_list[i]
else:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get('keywords_list', None)
for i in range(len(texts)):
text = texts[i]
if keywords_list:
keywords = keywords_list[i]
if not keywords:
keywords = keyword_table_handler.extract_keywords(text.page_content,
self._config.max_keywords_per_chunk)
else:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
self._save_dataset_keyword_table(keyword_table)
self._save_dataset_keyword_table(keyword_table)
def text_exists(self, id: str) -> bool:
keyword_table = self._get_dataset_keyword_table()
return id in set.union(*keyword_table.values())
def delete_by_ids(self, ids: list[str]) -> None:
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
self._save_dataset_keyword_table(keyword_table)
def delete_by_document_id(self, document_id: str):
# get segment ids by document_id
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.document_id == document_id
).all()
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
# get segment ids by document_id
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.document_id == document_id
).all()
ids = [segment.index_node_id for segment in segments]
ids = [segment.index_node_id for segment in segments]
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
self._save_dataset_keyword_table(keyword_table)
def search(
self, query: str,
@@ -106,13 +117,15 @@ class Jieba(BaseKeyword):
return documents
def delete(self) -> None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
if dataset_keyword_table.data_source_type != 'database':
file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
storage.delete(file_key)
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
if dataset_keyword_table.data_source_type != 'database':
file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
storage.delete(file_key)
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
@@ -135,33 +148,31 @@ class Jieba(BaseKeyword):
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8'))
def _get_dataset_keyword_table(self) -> Optional[dict]:
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
with redis_client.lock(lock_name, timeout=20):
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return keyword_table_dict['__data__']['table']
else:
keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table='',
data_source_type=keyword_data_source_type,
)
if keyword_data_source_type == 'database':
dataset_keyword_table.keyword_table = json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
db.session.add(dataset_keyword_table)
db.session.commit()
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return keyword_table_dict['__data__']['table']
else:
keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table='',
data_source_type=keyword_data_source_type,
)
if keyword_data_source_type == 'database':
dataset_keyword_table.keyword_table = json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
db.session.add(dataset_keyword_table)
db.session.commit()
return {}
return {}
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
for keyword in keywords:

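The net effect of this change is that every read-modify-write of the shared keyword table now happens under one Redis lock per dataset. A minimal sketch of the pattern, assuming the same redis_client and hypothetical load_table/save_table helpers:

from extensions.ext_redis import redis_client  # import path assumed from the surrounding project

def update_keyword_table(dataset_id: str, mutate):
    # serialize concurrent indexing tasks touching the same dataset
    lock_name = 'keyword_indexing_lock_{}'.format(dataset_id)
    with redis_client.lock(lock_name, timeout=600):
        table = load_table(dataset_id)   # hypothetical helper
        table = mutate(table)
        save_table(dataset_id, table)    # hypothetical helper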

@@ -20,16 +20,17 @@ class MilvusConfig(BaseModel):
password: str
secure: bool = False
batch_size: int = 100
database: str = "default"
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values.get('host'):
raise ValueError("config MILVUS_HOST is required")
if not values['port']:
if not values.get('port'):
raise ValueError("config MILVUS_PORT is required")
if not values['user']:
if not values.get('user'):
raise ValueError("config MILVUS_USER is required")
if not values['password']:
if not values.get('password'):
raise ValueError("config MILVUS_PASSWORD is required")
return values
@@ -39,7 +40,8 @@ class MilvusConfig(BaseModel):
'port': self.port,
'user': self.user,
'password': self.password,
'secure': self.secure
'secure': self.secure,
'db_name': self.database,
}
@@ -128,7 +130,8 @@ class MilvusVector(BaseVector):
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
db_name=self._client_config.database)
from pymilvus import utility
if utility.has_collection(self._collection_name, using=alias):
@@ -140,7 +143,8 @@ class MilvusVector(BaseVector):
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
db_name=self._client_config.database)
from pymilvus import utility
if not utility.has_collection(self._collection_name, using=alias):
@@ -192,7 +196,7 @@ class MilvusVector(BaseVector):
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user,
password=self._client_config.password)
password=self._client_config.password, db_name=self._client_config.database)
if not utility.has_collection(self._collection_name, using=alias):
from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.orm.types import infer_dtype_bydata

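A small, hypothetical example of wiring the new database field (all values are placeholders; db_name is what pymilvus receives, per the diff above):

config = MilvusConfig(
    host='127.0.0.1',
    port=19530,
    user='root',
    password='Milvus',
    secure=False,
    database='default',   # new field, forwarded to pymilvus as db_name
)
# to_milvus_params() now includes 'db_name': 'default' alongside the
# existing connection parameters
print(config.to_milvus_params())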

@@ -110,6 +110,7 @@ class Vector:
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
database=config.get('MILVUS_DATABASE'),
)
)
else:


@@ -34,6 +34,7 @@ class CSVExtractor(BaseExtractor):
def extract(self) -> list[Document]:
"""Load data into document objects."""
docs = []
try:
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
docs = self._read_from_file(csvfile)


@@ -1,59 +0,0 @@
import time
from collections.abc import Mapping
from typing import Any, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
class FakeLLM(SimpleChatModel):
"""Fake ChatModel for testing purposes."""
streaming: bool = False
"""Whether to stream the results or not."""
response: str
@property
def _llm_type(self) -> str:
return "fake-chat-model"
def _call(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""First try to lookup in queries, else return 'foo' or 'bar'."""
return self.response
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {"response": self.response}
def get_num_tokens(self, text: str) -> int:
return 0
def _generate(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
if self.streaming:
for token in output_str:
if run_manager:
run_manager.on_llm_new_token(token)
time.sleep(0.01)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
llm_output = {"token_usage": {
'prompt_tokens': 0,
'completion_tokens': 0,
'total_tokens': 0,
}}
return ChatResult(generations=[generation], llm_output=llm_output)

View File

@@ -1,46 +0,0 @@
from typing import Any, Optional
from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.rag.retrieval.agent.fake_llm import FakeLLM
class LLMChain(LCLLMChain):
model_config: ModelConfigWithCredentialsEntity
"""The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="")
parameters: dict[str, Any] = {}
def generate(
self,
input_list: list[dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
messages = prompts[0].to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
stream=False,
stop=stop,
model_parameters=self.parameters
)
generations = [
[Generation(text=result.message.content)]
]
return LLMResult(generations=generations)

View File

@@ -1,179 +0,0 @@
from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage
from langchain.tools import BaseTool
from pydantic import root_validator
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.rag.retrieval.agent.fake_llm import FakeLLM
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
An Multi Dataset Retrieve Agent driven by Router.
"""
model_config: ModelConfigWithCredentialsEntity
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
def should_use_agent(self, query: str):
"""
return should use agent
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.tools) == 1:
tool = next(iter(self.tools))
rst = tool.run(tool_input={'query': kwargs['input']})
# output = ''
# rst_json = json.loads(rst)
# for item in rst_json:
# output += f'{item["content"]}\n'
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
try:
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
else:
agent_decision.return_values['output'] = ''
return agent_decision
except Exception as e:
raise e
def real_plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
ai_message = AIMessage(
content=result.message.content or "",
additional_kwargs={
'function_call': {
'id': result.message.tool_calls[0].id,
**result.message.tool_calls[0].function.dict()
} if result.message.tool_calls else None
}
)
agent_decision = _parse_ai_message(ai_message)
return agent_decision
async def aplan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
raise NotImplementedError()
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigWithCredentialsEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseSingleActionAgent:
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_config=model_config,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)


@@ -1,259 +0,0 @@
import re
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.rag.retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
dataset_tools: Sequence[BaseTool]
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
return should use agent
Using the ReACT mode to determine whether an agent is needed is costly,
so it's better to just use an Agent for reasoning, which is cheaper.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.dataset_tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.dataset_tools) == 1:
tool = next(iter(self.dataset_tools))
rst = tool.run(tool_input={'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
raise e
try:
agent_decision = self.output_parser.parse(full_output)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
elif isinstance(tool_inputs, str):
agent_decision.tool_input = kwargs['input']
else:
agent_decision.return_values['output'] = ''
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
formatted_tools = "\n".join(tool_strings)
unique_tool_names = set(tool.name for tool in tools)
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def create_completion_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:
tools: List of tools the agent will have access to, used to format the
prompt.
prefix: String to put before the list of tools.
input_variables: List of input variables the final prompt will expect.
Returns:
A PromptTemplate with the template assembled from the pieces here.
"""
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
agent_scratchpad += action.log
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
if not isinstance(agent_scratchpad, str):
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_config.mode == "chat":
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
else:
return agent_scratchpad
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigWithCredentialsEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_config.mode == "chat":
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
else:
prompt = cls.create_completion_prompt(
tools,
prefix=prefix,
format_instructions=format_instructions,
input_variables=input_variables
)
llm_chain = LLMChain(
model_config=model_config,
prompt=prompt,
callback_manager=callback_manager,
parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
dataset_tools=tools,
**kwargs,
)


@@ -1,117 +0,0 @@
import logging
from typing import Optional, Union
from langchain.agents import AgentExecutor as LCAgentExecutor
from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent
from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_config: ModelConfigWithCredentialsEntity
tools: list[BaseTool]
summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None
memory: Optional[TokenBufferMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6
max_execution_time: Optional[float] = None
early_stopping_method: str = "generate"
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
class AgentExecuteResult(BaseModel):
strategy: PlanningStrategy
output: Optional[str]
configuration: AgentConfiguration
class AgentExecutor:
def __init__(self, configuration: AgentConfiguration):
self.configuration = configuration
self.agent = self._init_agent()
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
verbose=True
)
else:
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
return agent
def should_use_agent(self, query: str) -> bool:
return self.agent.should_use_agent(query)
def run(self, query: str) -> AgentExecuteResult:
moderation_result = moderation.check_moderation(
self.configuration.model_config,
query
)
if moderation_result:
return AgentExecuteResult(
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
strategy=self.configuration.strategy,
configuration=self.configuration
)
agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent,
tools=self.configuration.tools,
max_iterations=self.configuration.max_iterations,
max_execution_time=self.configuration.max_execution_time,
early_stopping_method=self.configuration.early_stopping_method,
callbacks=self.configuration.callbacks
)
try:
output = agent_executor.run(input=query)
except InvokeError as ex:
raise ex
except Exception as ex:
logging.exception("agent_executor run failed")
output = None
return AgentExecuteResult(
output=output,
strategy=self.configuration.strategy,
configuration=self.configuration
)
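The executor above gates every run behind a moderation check and maps each PlanningStrategy onto a concrete agent. A stripped-down sketch of that dispatch shape, with an illustrative check_moderation stub in place of the real moderation module:

from enum import Enum

class PlanningStrategy(Enum):
    ROUTER = "router"
    REACT_ROUTER = "react_router"

def check_moderation(query: str) -> bool:
    # stand-in: pretend nothing is flagged
    return False

def run(strategy: PlanningStrategy, query: str) -> str | None:
    if check_moderation(query):
        return "I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest."
    if strategy == PlanningStrategy.ROUTER:
        return f"function-call routing for: {query}"
    elif strategy == PlanningStrategy.REACT_ROUTER:
        return f"ReAct routing for: {query}"
    raise NotImplementedError(f"Unknown Agent Strategy: {strategy}")

print(run(PlanningStrategy.ROUTER, "which dataset covers billing?"))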

View File

@@ -1,5 +1,7 @@
import threading
from typing import Optional, cast
from flask import Flask, current_app
from langchain.tools import BaseTool
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
@@ -7,17 +9,35 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
from core.rerank.rerank import RerankRunner
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.dataset import Dataset, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
class DatasetRetrieval:
def retrieve(self, tenant_id: str,
def retrieve(self, app_id: str, user_id: str, tenant_id: str,
model_config: ModelConfigWithCredentialsEntity,
config: DatasetEntity,
query: str,
@@ -27,6 +47,8 @@ class DatasetRetrieval:
memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
"""
Retrieve dataset.
:param app_id: app_id
:param user_id: user_id
:param tenant_id: tenant id
:param model_config: model config
:param config: dataset config
@@ -38,12 +60,22 @@ class DatasetRetrieval:
:return:
"""
dataset_ids = config.dataset_ids
if len(dataset_ids) == 0:
return None
retrieve_config = config.retrieve_config
# check whether the model supports tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,
model=model_config.model
)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model_config.model,
@@ -59,38 +91,291 @@ class DatasetRetrieval:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
# pass if dataset is not available
if not dataset:
continue
# pass if dataset has no available documents
if dataset and dataset.available_document_count == 0:
continue
available_datasets.append(dataset)
all_documents = []
user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
all_documents = self.single_retrieve(app_id, tenant_id, user_id, user_from, available_datasets, query,
model_instance,
model_config, planning_strategy)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
all_documents = self.multiple_retrieve(app_id, tenant_id, user_id, user_from,
available_datasets, query, retrieve_config.top_k,
retrieve_config.score_threshold,
retrieve_config.reranking_model.get('reranking_provider_name'),
retrieve_config.reranking_model.get('reranking_model_name'))
document_score_list = {}
for item in all_documents:
if 'score' in item.metadata and item.metadata['score']:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids)
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(segments,
key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
float('inf')))
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
else:
document_context_list.append(segment.content)
if show_retrieve_source:
context_list = []
resource_number = 1
for segment in sorted_segments:
dataset = Dataset.query.filter_by(
id=segment.dataset_id
).first()
document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).first()
if dataset and document:
source = {
'position': resource_number,
'dataset_id': dataset.id,
'dataset_name': dataset.name,
'document_id': document.id,
'document_name': document.name,
'data_source_type': document.data_source_type,
'segment_id': segment.id,
'retriever_from': invoke_from.to_source(),
'score': document_score_list.get(segment.index_node_id, None)
}
if invoke_from.to_source() == 'dev':
source['hit_count'] = segment.hit_count
source['word_count'] = segment.word_count
source['segment_position'] = segment.position
source['index_node_hash'] = segment.index_node_hash
if segment.answer:
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
else:
source['content'] = segment.content
context_list.append(source)
resource_number += 1
if hit_callback:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
return ''
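One detail worth noting in retrieve above: hits come back from the index in relevance order, so the fetched DocumentSegment rows are re-sorted to match. A minimal sketch of that position-map idiom, independent of the database layer (the ids and dicts are illustrative):

index_node_ids = ["n3", "n1", "n7"]            # relevance order from retrieval
segments = [{"index_node_id": "n1"}, {"index_node_id": "n7"}, {"index_node_id": "n3"}]

# map each id to its rank; unknown ids sort to the end
position = {node_id: rank for rank, node_id in enumerate(index_node_ids)}
sorted_segments = sorted(segments, key=lambda s: position.get(s["index_node_id"], float("inf")))
print([s["index_node_id"] for s in sorted_segments])  # ['n3', 'n1', 'n7']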
def single_retrieve(self, app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
):
tools = []
for dataset in available_datasets:
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
description = description.replace('\n', '').replace('\r', '')
message_tool = PromptMessageTool(
name=dataset.id,
description=description,
parameters={
"type": "object",
"properties": {},
"required": [],
}
)
tools.append(message_tool)
dataset_id = None
if planning_strategy == PlanningStrategy.REACT_ROUTER:
react_multi_dataset_router = ReactMultiDatasetRouter()
dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance,
user_id, tenant_id)
elif planning_strategy == PlanningStrategy.ROUTER:
function_call_router = FunctionCallMultiDatasetRouter()
dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
if dataset_id:
# get retrieval model config
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if dataset:
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
# get retrieval method
if dataset.indexing_technique == "economy":
retrival_method = 'keyword_search'
else:
retrival_method = retrieval_model_config['search_method']
# get reranking model
reranking_model = retrieval_model_config['reranking_model'] \
if retrieval_model_config['reranking_enable'] else None
# get score threshold
score_threshold = .0
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
query=query,
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
if results:
self._on_retrival_end(results)
return results
return []
def multiple_retrieve(self,
app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
top_k: int,
score_threshold: float,
reranking_provider_name: str,
reranking_model_name: str):
threads = []
all_documents = []
dataset_ids = [dataset.id for dataset in available_datasets]
for dataset in available_datasets:
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset.id,
'query': query,
'top_k': top_k,
'all_documents': all_documents,
})
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
# do rerank for searched documents
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_provider_name,
model_type=ModelType.RERANK,
model=reranking_model_name
)
rerank_runner = RerankRunner(rerank_model_instance)
all_documents = rerank_runner.run(query, all_documents,
score_threshold,
top_k)
self._on_query(query, dataset_ids, app_id, user_from, user_id)
if all_documents:
self._on_retrival_end(all_documents)
return all_documents
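multiple_retrieve above fans one query out to every dataset on its own thread, collects results into a shared list, then reranks. A minimal sketch of that fan-out/join shape, where fetch_one is an illustrative stand-in for the Flask-aware _retriever:

import threading

def fetch_one(dataset_id: str, query: str, all_documents: list) -> None:
    # stand-in for RetrievalService.retrieve inside an app context
    all_documents.append(f"{dataset_id}:{query}")

def fan_out(dataset_ids: list[str], query: str) -> list[str]:
    all_documents: list[str] = []
    threads = []
    for dataset_id in dataset_ids:
        t = threading.Thread(target=fetch_one, args=(dataset_id, query, all_documents))
        threads.append(t)
        t.start()
    for t in threads:
        t.join()  # wait for every dataset before reranking the merged results
    return all_documents

print(fan_out(["ds1", "ds2"], "pricing"))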
def _on_retrival_end(self, documents: list[Document]) -> None:
"""Handle retrieval end."""
for document in documents:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata['doc_id']
)
if 'dataset_id' in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False
)
db.session.commit()
def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
"""
Handle query.
"""
if not query:
return
for dataset_id in dataset_ids:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source='app',
source_app_id=app_id,
created_by_role=user_from,
created_by=user_id
)
db.session.add(dataset_query)
db.session.commit()
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if not dataset:
return []
# get the retrieval model; if it is not set, use the default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(retrival_method='keyword_search',
dataset_id=dataset.id,
query=query,
top_k=top_k
)
if documents:
all_documents.extend(documents)
else:
if top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=retrieval_model['score_threshold']
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model['reranking_model']
if retrieval_model['reranking_enable'] else None
)
all_documents.extend(documents)
def to_dataset_retriever_tool(self, tenant_id: str,
dataset_ids: list[str],

View File

@@ -12,8 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm.llm_node import LLMNode
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -55,11 +54,10 @@ class ReactMultiDatasetRouter:
self,
query: str,
dataset_tools: list[PromptMessageTool],
node_data: KnowledgeRetrievalNodeData,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
user_id: str,
tenant_id: str,
tenant_id: str
) -> Union[str, None]:
"""Given input, decided what to do.
@@ -72,7 +70,8 @@ class ReactMultiDatasetRouter:
return dataset_tools[0].name
try:
return self._react_invoke(query=query, node_data=node_data, model_config=model_config, model_instance=model_instance,
return self._react_invoke(query=query, model_config=model_config,
model_instance=model_instance,
tools=dataset_tools, user_id=user_id, tenant_id=tenant_id)
except Exception as e:
return None
@@ -80,7 +79,6 @@ class ReactMultiDatasetRouter:
def _react_invoke(
self,
query: str,
node_data: KnowledgeRetrievalNodeData,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
tools: Sequence[PromptMessageTool],
@@ -121,7 +119,7 @@ class ReactMultiDatasetRouter:
model_config=model_config
)
result_text, usage = self._invoke_llm(
node_data=node_data,
completion_param=model_config.parameters,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
@@ -134,10 +132,11 @@ class ReactMultiDatasetRouter:
return agent_decision.tool
return None
def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData,
def _invoke_llm(self, completion_param: dict,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: list[str], user_id: str, tenant_id: str) -> tuple[str, LLMUsage]:
stop: list[str], user_id: str, tenant_id: str
) -> tuple[str, LLMUsage]:
"""
Invoke large language model
:param node_data: node data
@@ -148,7 +147,7 @@ class ReactMultiDatasetRouter:
"""
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=node_data.single_retrieval_config.model.completion_params,
model_parameters=completion_param,
stop=stop,
stream=True,
user=user_id,
@@ -203,7 +202,8 @@ class ReactMultiDatasetRouter:
) -> list[ChatModelMessage]:
tool_strings = []
for tool in tools:
tool_strings.append(f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
tool_strings.append(
f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
formatted_tools = "\n".join(tool_strings)
unique_tool_names = set(tool.name for tool in tools)
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)

View File

@@ -38,8 +38,10 @@ Action:
```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
{{historic_messages}}
Question: {{query}}
Thought: {{agent_scratchpad}}"""
{{agent_scratchpad}}
Thought:"""
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
Thought:"""

View File

@@ -1,11 +1,92 @@
from typing import Any
import logging
from typing import Any, Optional
from langchain.utilities import ArxivAPIWrapper
import arxiv
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
logger = logging.getLogger(__name__)
class ArxivAPIWrapper(BaseModel):
"""Wrapper around ArxivAPI.
To use, you should have the ``arxiv`` python package installed.
https://lukasschwab.me/arxiv.py/index.html
This wrapper will use the Arxiv API to conduct searches and
fetch document summaries. By default, it will return the document summaries
of the top-k results.
It limits the Document content by doc_content_chars_max.
Set doc_content_chars_max=None if you don't want to limit the content size.
Args:
top_k_results: number of the top-scored document used for the arxiv tool
ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
load_max_docs: a limit to the number of loaded documents
load_all_available_meta:
if True: the `metadata` of the loaded Documents contains all available
meta info (see https://lukasschwab.me/arxiv.py/index.html#Result),
if False: the `metadata` contains only the published date, title,
authors and summary.
doc_content_chars_max: an optional cut limit for the length of a document's
content
Example:
.. code-block:: python
arxiv = ArxivAPIWrapper(
top_k_results = 3,
ARXIV_MAX_QUERY_LENGTH = 300,
load_max_docs = 3,
load_all_available_meta = False,
doc_content_chars_max = 40000
)
arxiv.run("tree of thought llm")
"""
arxiv_search = arxiv.Search #: :meta private:
arxiv_exceptions = (
arxiv.ArxivError,
arxiv.UnexpectedEmptyPageError,
arxiv.HTTPError,
) # :meta private:
top_k_results: int = 3
ARXIV_MAX_QUERY_LENGTH = 300
load_max_docs: int = 100
load_all_available_meta: bool = False
doc_content_chars_max: Optional[int] = 4000
def run(self, query: str) -> str:
"""
Performs an arxiv search and returns a single string
with the publish date, title, authors, and summary
for each article, separated by two newlines.
If an error occurs or no documents are found, error text
is returned instead. Wrapper for
https://lukasschwab.me/arxiv.py/index.html#Search
Args:
query: a plaintext search query
""" # noqa: E501
try:
results = self.arxiv_search( # type: ignore
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
).results()
except self.arxiv_exceptions as ex:
return f"Arxiv exception: {ex}"
docs = [
f"Published: {result.updated.date()}\n"
f"Title: {result.title}\n"
f"Authors: {', '.join(a.name for a in result.authors)}\n"
f"Summary: {result.summary}"
for result in results
]
if docs:
return "\n\n".join(docs)[: self.doc_content_chars_max]
else:
return "No good Arxiv Result was found"
class ArxivSearchInput(BaseModel):
query: str = Field(..., description="Search query.")
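A short usage sketch for the inlined wrapper above, assuming the arxiv package is installed (a live network call, so treat it as illustrative):

# pip install arxiv
import arxiv

search = arxiv.Search(query="tree of thought llm", max_results=3)
for result in search.results():
    print(result.updated.date(), "-", result.title)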

View File

@@ -12,6 +12,7 @@ class BingSearchTool(BuiltinTool):
def _invoke_bing(self,
user_id: str,
server_url: str,
subscription_key: str, query: str, limit: int,
result_type: str, market: str, lang: str,
filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
@@ -26,7 +27,7 @@ class BingSearchTool(BuiltinTool):
}
query = quote(query)
server_url = f'{self.url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
response = get(server_url, headers=headers)
if response.status_code != 200:
@@ -136,6 +137,7 @@ class BingSearchTool(BuiltinTool):
self._invoke_bing(
user_id='test',
server_url=server_url,
subscription_key=key,
query=query,
limit=limit,
@@ -188,6 +190,7 @@ class BingSearchTool(BuiltinTool):
return self._invoke_bing(
user_id=user_id,
server_url=server_url,
subscription_key=key,
query=query,
limit=limit,
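The fix above threads the user-configured server_url (and subscription_key) through _invoke_bing instead of hard-coding the endpoint. A small sketch of the resulting URL construction; the header name follows the public Bing Web Search API and the endpoint string is a placeholder, both assumptions rather than code shown in the hunk:

from urllib.parse import quote

def build_bing_request(server_url: str, subscription_key: str, query: str,
                       market_code: str, limit: int, filters: list[str]) -> tuple[str, dict]:
    # "Ocp-Apim-Subscription-Key" is assumed from the Bing API docs, not from this diff
    headers = {"Ocp-Apim-Subscription-Key": subscription_key}
    url = f'{server_url}?q={quote(query)}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
    return url, headers

url, headers = build_bing_request("https://api.bing.microsoft.com/v7.0/search",
                                  "<key>", "dify llm", "en-US", 5, ["Webpages"])
print(url)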

View File

@@ -1,11 +1,95 @@
from typing import Any
import json
from typing import Any, Optional
from langchain.tools import BraveSearch
import requests
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class BraveSearchWrapper(BaseModel):
"""Wrapper around the Brave search engine."""
api_key: str
"""The API key to use for the Brave search engine."""
search_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the search request."""
base_url = "https://api.search.brave.com/res/v1/web/search"
"""The base URL for the Brave search engine."""
def run(self, query: str) -> str:
"""Query the Brave search engine and return the results as a JSON string.
Args:
query: The query to search for.
Returns: The results as a JSON string.
"""
web_search_results = self._search_request(query=query)
final_results = [
{
"title": item.get("title"),
"link": item.get("url"),
"snippet": item.get("description"),
}
for item in web_search_results
]
return json.dumps(final_results)
def _search_request(self, query: str) -> list[dict]:
headers = {
"X-Subscription-Token": self.api_key,
"Accept": "application/json",
}
req = requests.PreparedRequest()
params = {**self.search_kwargs, **{"q": query}}
req.prepare_url(self.base_url, params)
if req.url is None:
raise ValueError("prepared url is None, this should not happen")
response = requests.get(req.url, headers=headers)
if not response.ok:
raise Exception(f"HTTP error {response.status_code}")
return response.json().get("web", {}).get("results", [])
class BraveSearch(BaseModel):
"""Tool that queries the BraveSearch."""
name = "brave_search"
description = (
"a search engine. "
"useful for when you need to answer questions about current events."
" input should be a search query."
)
search_wrapper: BraveSearchWrapper
@classmethod
def from_api_key(
cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any
) -> "BraveSearch":
"""Create a tool from an api key.
Args:
api_key: The api key to use.
search_kwargs: Any additional kwargs to pass to the search wrapper.
**kwargs: Any additional kwargs to pass to the tool.
Returns:
A tool.
"""
wrapper = BraveSearchWrapper(api_key=api_key, search_kwargs=search_kwargs or {})
return cls(search_wrapper=wrapper, **kwargs)
def _run(
self,
query: str,
) -> str:
"""Use the tool."""
return self.search_wrapper.run(query)
class BraveSearchTool(BuiltinTool):
"""
Tool for performing a search using Brave search engine.
@@ -31,7 +115,7 @@ class BraveSearchTool(BuiltinTool):
tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count})
results = tool.run(query)
results = tool._run(query)
if not results:
return self.create_text_message(f"No results found for '{query}' in Tavily")

View File

@@ -1,16 +1,147 @@
from typing import Any
from typing import Any, Optional
from langchain.tools import DuckDuckGoSearchRun
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class DuckDuckGoSearchAPIWrapper(BaseModel):
"""Wrapper for DuckDuckGo Search API.
Free and does not require any setup.
"""
region: Optional[str] = "wt-wt"
safesearch: str = "moderate"
time: Optional[str] = "y"
max_results: int = 5
def get_snippets(self, query: str) -> list[str]:
"""Run query through DuckDuckGo and return concatenated results."""
from duckduckgo_search import DDGS
with DDGS() as ddgs:
results = ddgs.text(
query,
region=self.region,
safesearch=self.safesearch,
timelimit=self.time,
)
if results is None:
return ["No good DuckDuckGo Search Result was found"]
snippets = []
for i, res in enumerate(results, 1):
if res is not None:
snippets.append(res["body"])
if len(snippets) == self.max_results:
break
return snippets
def run(self, query: str) -> str:
snippets = self.get_snippets(query)
return " ".join(snippets)
def results(
self, query: str, num_results: int, backend: str = "api"
) -> list[dict[str, str]]:
"""Run query through DuckDuckGo and return metadata.
Args:
query: The query to search for.
num_results: The number of results to return.
Returns:
A list of dictionaries with the following keys:
snippet - The description of the result.
title - The title of the result.
link - The link to the result.
"""
from duckduckgo_search import DDGS
with DDGS() as ddgs:
results = ddgs.text(
query,
region=self.region,
safesearch=self.safesearch,
timelimit=self.time,
backend=backend,
)
if results is None:
return [{"Result": "No good DuckDuckGo Search Result was found"}]
def to_metadata(result: dict) -> dict[str, str]:
if backend == "news":
return {
"date": result["date"],
"title": result["title"],
"snippet": result["body"],
"source": result["source"],
"link": result["url"],
}
return {
"snippet": result["body"],
"title": result["title"],
"link": result["href"],
}
formatted_results = []
for i, res in enumerate(results, 1):
if res is not None:
formatted_results.append(to_metadata(res))
if len(formatted_results) == num_results:
break
return formatted_results
class DuckDuckGoSearchRun(BaseModel):
"""Tool that queries the DuckDuckGo search API."""
name = "duckduckgo_search"
description = (
"A wrapper around DuckDuckGo Search. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query."
)
api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
default_factory=DuckDuckGoSearchAPIWrapper
)
def _run(
self,
query: str,
) -> str:
"""Use the tool."""
return self.api_wrapper.run(query)
class DuckDuckGoSearchResults(BaseModel):
"""Tool that queries the DuckDuckGo search API and gets back json."""
name = "DuckDuckGo Results JSON"
description = (
"A wrapper around Duck Duck Go Search. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query. Output is a JSON array of the query results"
)
num_results: int = 4
api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
default_factory=DuckDuckGoSearchAPIWrapper
)
backend: str = "api"
def _run(
self,
query: str,
) -> str:
"""Use the tool."""
res = self.api_wrapper.results(query, self.num_results, backend=self.backend)
res_strs = [", ".join([f"{k}: {v}" for k, v in d.items()]) for d in res]
return ", ".join([f"[{rs}]" for rs in res_strs])
class DuckDuckGoInput(BaseModel):
query: str = Field(..., description="Search query.")
class DuckDuckGoSearchTool(BuiltinTool):
"""
Tool for performing a search using DuckDuckGo search engine.
@@ -34,7 +165,7 @@ class DuckDuckGoSearchTool(BuiltinTool):
tool = DuckDuckGoSearchRun(args_schema=DuckDuckGoInput)
result = tool.run(query)
result = tool._run(query)
return self.create_text_message(self.summary(user_id=user_id, content=result))
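A minimal standalone version of the snippet-collection loop the inlined DuckDuckGo wrapper above performs (requires the duckduckgo-search package; the live call is left commented):

# pip install duckduckgo-search
from duckduckgo_search import DDGS

def ddg_snippets(query: str, max_results: int = 5) -> list[str]:
    with DDGS() as ddgs:
        results = ddgs.text(query, region="wt-wt", safesearch="moderate", timelimit="y")
        if results is None:
            return ["No good DuckDuckGo Search Result was found"]
        snippets = []
        for res in results:
            snippets.append(res["body"])
            if len(snippets) == max_results:
                break  # stop once enough snippets are collected
        return snippets

# print(" ".join(ddg_snippets("dify release notes")))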

View File

@@ -70,43 +70,44 @@ class SerpAPI:
raise ValueError(f"Got error from SerpAPI: {res['error']}")
if typ == "text":
toret = ""
if "answer_box" in res.keys() and type(res["answer_box"]) == list:
res["answer_box"] = res["answer_box"][0]
res["answer_box"] = res["answer_box"][0] + "\n"
if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
toret = res["answer_box"]["answer"]
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret = res["answer_box"]["snippet"]
elif (
toret += res["answer_box"]["answer"] + "\n"
if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret += res["answer_box"]["snippet"] + "\n"
if (
"answer_box" in res.keys()
and "snippet_highlighted_words" in res["answer_box"].keys()
):
toret = res["answer_box"]["snippet_highlighted_words"][0]
elif (
for item in res["answer_box"]["snippet_highlighted_words"]:
toret += item + "\n"
if (
"sports_results" in res.keys()
and "game_spotlight" in res["sports_results"].keys()
):
toret = res["sports_results"]["game_spotlight"]
elif (
toret += res["sports_results"]["game_spotlight"] + "\n"
if (
"shopping_results" in res.keys()
and "title" in res["shopping_results"][0].keys()
):
toret = res["shopping_results"][:3]
elif (
toret += res["shopping_results"][:3] + "\n"
if (
"knowledge_graph" in res.keys()
and "description" in res["knowledge_graph"].keys()
):
toret = res["knowledge_graph"]["description"]
elif "snippet" in res["organic_results"][0].keys():
toret = res["organic_results"][0]["snippet"]
elif "link" in res["organic_results"][0].keys():
toret = res["organic_results"][0]["link"]
elif (
toret = res["knowledge_graph"]["description"] + "\n"
if "snippet" in res["organic_results"][0].keys():
for item in res["organic_results"]:
toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n"
if (
"images_results" in res.keys()
and "thumbnail" in res["images_results"][0].keys()
):
thumbnails = [item["thumbnail"] for item in res["images_results"][:10]]
toret = thumbnails
else:
if toret == "":
toret = "No good search result found"
elif typ == "link":
if "knowledge_graph" in res.keys() and "title" in res["knowledge_graph"].keys() \

View File

@@ -1,16 +1,187 @@
import json
import time
import urllib.error
import urllib.parse
import urllib.request
from typing import Any
from langchain.tools import PubmedQueryRun
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class PubMedAPIWrapper(BaseModel):
"""
Wrapper around PubMed API.
This wrapper will use the PubMed API to conduct searches and fetch
document summaries. By default, it will return the document summaries
of the top-k results of an input search.
Parameters:
top_k_results: number of the top-scored document used for the PubMed tool
load_max_docs: a limit to the number of loaded documents
load_all_available_meta:
if True: the `metadata` of the loaded Documents gets all available meta info
(see https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch)
if False: the `metadata` gets only the most informative fields.
"""
base_url_esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
base_url_efetch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
max_retry = 5
sleep_time = 0.2
# Default values for the parameters
top_k_results: int = 3
load_max_docs: int = 25
ARXIV_MAX_QUERY_LENGTH = 300
doc_content_chars_max: int = 2000
load_all_available_meta: bool = False
email: str = "your_email@example.com"
def run(self, query: str) -> str:
"""
Run PubMed search and get the article meta information.
See https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch
It uses only the most informative fields of article meta information.
"""
try:
# Retrieve the top-k results for the query
docs = [
f"Published: {result['pub_date']}\nTitle: {result['title']}\n"
f"Summary: {result['summary']}"
for result in self.load(query[: self.ARXIV_MAX_QUERY_LENGTH])
]
# Join the results and limit the character count
return (
"\n\n".join(docs)[:self.doc_content_chars_max]
if docs
else "No good PubMed Result was found"
)
except Exception as ex:
return f"PubMed exception: {ex}"
def load(self, query: str) -> list[dict]:
"""
Search PubMed for documents matching the query.
Return a list of dictionaries containing the document metadata.
"""
url = (
self.base_url_esearch
+ "db=pubmed&term="
+ urllib.parse.quote(query)
+ f"&retmode=json&retmax={self.top_k_results}&usehistory=y"
)
result = urllib.request.urlopen(url)
text = result.read().decode("utf-8")
json_text = json.loads(text)
articles = []
webenv = json_text["esearchresult"]["webenv"]
for uid in json_text["esearchresult"]["idlist"]:
article = self.retrieve_article(uid, webenv)
articles.append(article)
# Convert the list of articles to a JSON string
return articles
def retrieve_article(self, uid: str, webenv: str) -> dict:
url = (
self.base_url_efetch
+ "db=pubmed&retmode=xml&id="
+ uid
+ "&webenv="
+ webenv
)
retry = 0
while True:
try:
result = urllib.request.urlopen(url)
break
except urllib.error.HTTPError as e:
if e.code == 429 and retry < self.max_retry:
# Too Many Requests error
# wait for an exponentially increasing amount of time
print(
f"Too Many Requests, "
f"waiting for {self.sleep_time:.2f} seconds..."
)
time.sleep(self.sleep_time)
self.sleep_time *= 2
retry += 1
else:
raise e
xml_text = result.read().decode("utf-8")
# Get title
title = ""
if "<ArticleTitle>" in xml_text and "</ArticleTitle>" in xml_text:
start_tag = "<ArticleTitle>"
end_tag = "</ArticleTitle>"
title = xml_text[
xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
]
# Get abstract
abstract = ""
if "<AbstractText>" in xml_text and "</AbstractText>" in xml_text:
start_tag = "<AbstractText>"
end_tag = "</AbstractText>"
abstract = xml_text[
xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
]
# Get publication date
pub_date = ""
if "<PubDate>" in xml_text and "</PubDate>" in xml_text:
start_tag = "<PubDate>"
end_tag = "</PubDate>"
pub_date = xml_text[
xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
]
# Return article as dictionary
article = {
"uid": uid,
"title": title,
"summary": abstract,
"pub_date": pub_date,
}
return article
class PubmedQueryRun(BaseModel):
"""Tool that searches the PubMed API."""
name = "PubMed"
description = (
"A wrapper around PubMed.org "
"Useful for when you need to answer questions about Physics, Mathematics, "
"Computer Science, Quantitative Biology, Quantitative Finance, Statistics, "
"Electrical Engineering, and Economics "
"from scientific articles on PubMed.org. "
"Input should be a search query."
)
api_wrapper: PubMedAPIWrapper = Field(default_factory=PubMedAPIWrapper)
def _run(
self,
query: str,
) -> str:
"""Use the Arxiv tool."""
return self.api_wrapper.run(query)
class PubMedInput(BaseModel):
query: str = Field(..., description="Search query.")
class PubMedSearchTool(BuiltinTool):
"""
Tool for performing a search using PubMed search engine.
@@ -34,7 +205,7 @@ class PubMedSearchTool(BuiltinTool):
tool = PubmedQueryRun(args_schema=PubMedInput)
result = tool.run(query)
result = tool._run(query)
return self.create_text_message(self.summary(user_id=user_id, content=result))
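retrieve_article above retries HTTP 429 responses with an exponentially growing sleep. The same shape, reduced to a self-contained helper; RuntimeError and the flaky fetcher are illustrative stand-ins for urllib.error.HTTPError and urlopen:

import time

def fetch_with_backoff(fetch, url: str, max_retry: int = 5, sleep_time: float = 0.2):
    retry = 0
    while True:
        try:
            return fetch(url)
        except RuntimeError:  # stand-in for HTTPError with code 429
            if retry >= max_retry:
                raise
            time.sleep(sleep_time)
            sleep_time *= 2  # double the wait on each attempt
            retry += 1

calls = iter([RuntimeError("429"), RuntimeError("429"), "ok"])
def flaky(url):
    result = next(calls)
    if isinstance(result, Exception):
        raise result
    return result

print(fetch_with_backoff(flaky, "https://example.org"))  # "ok" after two retries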

View File

@@ -1,11 +1,81 @@
from typing import Any, Union
from typing import Any, Optional, Union
from langchain.utilities import TwilioAPIWrapper
from pydantic import BaseModel, validator
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class TwilioAPIWrapper(BaseModel):
"""Messaging Client using Twilio.
To use, you should have the ``twilio`` python package installed,
and the environment variables ``TWILIO_ACCOUNT_SID``, ``TWILIO_AUTH_TOKEN``, and
``TWILIO_FROM_NUMBER``, or pass `account_sid`, `auth_token`, and `from_number` as
named parameters to the constructor.
Example:
.. code-block:: python
from langchain.utilities.twilio import TwilioAPIWrapper
twilio = TwilioAPIWrapper(
account_sid="ACxxx",
auth_token="xxx",
from_number="+10123456789"
)
twilio.run('test', '+12484345508')
"""
client: Any #: :meta private:
account_sid: Optional[str] = None
"""Twilio account string identifier."""
auth_token: Optional[str] = None
"""Twilio auth token."""
from_number: Optional[str] = None
"""A Twilio phone number in [E.164](https://www.twilio.com/docs/glossary/what-e164)
format, an
[alphanumeric sender ID](https://www.twilio.com/docs/sms/send-messages#use-an-alphanumeric-sender-id),
or a [Channel Endpoint address](https://www.twilio.com/docs/sms/channels#channel-addresses)
that is enabled for the type of message you want to send. Phone numbers or
[short codes](https://www.twilio.com/docs/sms/api/short-code) purchased from
Twilio also work here. You cannot, for example, spoof messages from a private
cell phone number. If you are using `messaging_service_sid`, this parameter
must be empty.
""" # noqa: E501
@validator("client", pre=True, always=True)
def set_validator(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
try:
from twilio.rest import Client
except ImportError:
raise ImportError(
"Could not import twilio python package. "
"Please install it with `pip install twilio`."
)
account_sid = values.get("account_sid")
auth_token = values.get("auth_token")
values["from_number"] = values.get("from_number")
values["client"] = Client(account_sid, auth_token)
return values
def run(self, body: str, to: str) -> str:
"""Run body through Twilio and respond with message sid.
Args:
body: The text of the message you want to send. Can be up to 1,600
characters in length.
to: The destination phone number in
[E.164](https://www.twilio.com/docs/glossary/what-e164) format for
SMS/MMS or
[Channel user address](https://www.twilio.com/docs/sms/channels#channel-addresses)
for other 3rd-party channels.
""" # noqa: E501
message = self.client.messages.create(to, from_=self.from_number, body=body)
return message.sid
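A short usage sketch of the same Twilio send that the wrapper above performs (requires the twilio package and live credentials, so the call is left commented; all values are placeholders):

# pip install twilio
from twilio.rest import Client

def send_sms(account_sid: str, auth_token: str, from_number: str, body: str, to: str) -> str:
    client = Client(account_sid, auth_token)
    message = client.messages.create(to=to, from_=from_number, body=body)
    return message.sid  # Twilio's identifier for the queued message

# sid = send_sms("ACxxx", "xxx", "+10123456789", "test", "+12484345508")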
class SendMessageTool(BuiltinTool):
"""
A tool for sending messages using Twilio API.

View File

@@ -1,16 +1,79 @@
from typing import Any, Union
from typing import Any, Optional, Union
from langchain import WikipediaAPIWrapper
from langchain.tools import WikipediaQueryRun
from pydantic import BaseModel, Field
import wikipedia
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
WIKIPEDIA_MAX_QUERY_LENGTH = 300
class WikipediaInput(BaseModel):
query: str = Field(..., description="search query.")
class WikipediaAPIWrapper:
"""Wrapper around WikipediaAPI.
To use, you should have the ``wikipedia`` python package installed.
This wrapper will use the Wikipedia API to conduct searches and
fetch page summaries. By default, it will return the page summaries
of the top-k results.
It limits the Document content by doc_content_chars_max.
"""
top_k_results: int = 3
lang: str = "en"
load_all_available_meta: bool = False
doc_content_chars_max: int = 4000
def __init__(self, doc_content_chars_max: int = 4000):
self.doc_content_chars_max = doc_content_chars_max
def run(self, query: str) -> str:
"""Run Wikipedia search and get page summaries."""
wikipedia.set_lang(self.lang)
wiki_client = wikipedia
page_titles = wiki_client.search(query[:WIKIPEDIA_MAX_QUERY_LENGTH])
summaries = []
for page_title in page_titles[: self.top_k_results]:
if wiki_page := self._fetch_page(page_title):
if summary := self._formatted_page_summary(page_title, wiki_page):
summaries.append(summary)
if not summaries:
return "No good Wikipedia Search Result was found"
return "\n\n".join(summaries)[: self.doc_content_chars_max]
@staticmethod
def _formatted_page_summary(page_title: str, wiki_page: Any) -> Optional[str]:
return f"Page: {page_title}\nSummary: {wiki_page.summary}"
def _fetch_page(self, page: str) -> Optional[str]:
try:
return wikipedia.page(title=page, auto_suggest=False)
except (
wikipedia.exceptions.PageError,
wikipedia.exceptions.DisambiguationError,
):
return None
class WikipediaQueryRun:
"""Tool that searches the Wikipedia API."""
name = "Wikipedia"
description = (
"A wrapper around Wikipedia. "
"Useful for when you need to answer general questions about "
"people, places, companies, facts, historical events, or other subjects. "
"Input should be a search query."
)
api_wrapper: WikipediaAPIWrapper
def __init__(self, api_wrapper: WikipediaAPIWrapper):
self.api_wrapper = api_wrapper
def _run(
self,
query: str,
) -> str:
"""Use the Wikipedia tool."""
return self.api_wrapper.run(query)
class WikiPediaSearchTool(BuiltinTool):
def _invoke(self,
user_id: str,
@@ -24,14 +87,10 @@ class WikiPediaSearchTool(BuiltinTool):
return self.create_text_message('Please input query')
tool = WikipediaQueryRun(
name="wikipedia",
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
args_schema=WikipediaInput
)
result = tool.run(tool_input={
'query': query
})
result = tool._run(query)
return self.create_text_message(self.summary(user_id=user_id,content=result))
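A self-contained version of the search-then-summarize loop the inlined wrapper above implements (requires the wikipedia package; the live call is left commented):

# pip install wikipedia
import wikipedia

def wiki_summaries(query: str, top_k: int = 3, chars_max: int = 4000) -> str:
    wikipedia.set_lang("en")
    summaries = []
    for title in wikipedia.search(query[:300])[:top_k]:
        try:
            page = wikipedia.page(title=title, auto_suggest=False)
        except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError):
            continue  # skip pages that cannot be resolved cleanly
        summaries.append(f"Page: {title}\nSummary: {page.summary}")
    if not summaries:
        return "No good Wikipedia Search Result was found"
    return "\n\n".join(summaries)[:chars_max]

# print(wiki_summaries("large language model"))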

View File

@@ -22,9 +22,6 @@ class ValueType(Enum):
class VariablePool:
variables_mapping = {}
user_inputs: dict
system_variables: dict[SystemVariable, Any]
def __init__(self, system_variables: dict[SystemVariable, Any],
user_inputs: dict) -> None:
@@ -34,6 +31,7 @@ class VariablePool:
# 'query': 'abc',
# 'files': []
# }
self.variables_mapping = {}
self.user_inputs = user_inputs
self.system_variables = system_variables
for system_variable, value in system_variables.items():
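The fix above moves variables_mapping from a class attribute into __init__. That matters because a mutable class attribute is shared across every instance, which is exactly how one variable pool's mappings could bleed into another. A minimal demonstration of the pitfall, with illustrative class names:

class Buggy:
    mapping = {}  # one dict shared by all instances

class Fixed:
    def __init__(self):
        self.mapping = {}  # fresh dict per instance

a, b = Buggy(), Buggy()
a.mapping["x"] = 1
print(b.mapping)  # {'x': 1}, leaked between pools

c, d = Fixed(), Fixed()
c.mapping["x"] = 1
print(d.mapping)  # {}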

View File

@@ -234,6 +234,9 @@ class CodeNode(BaseNode):
parameters_validated = {}
for output_name, output_config in output_schema.items():
dot = '.' if prefix else ''
if output_name not in result:
raise ValueError(f'Output {prefix}{dot}{output_name} is missing.')
if output_config.type == 'object':
# check if output is object
if not isinstance(result.get(output_name), dict):
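The hunk above (truncated in this view) adds a guard so a code node fails fast when a declared output is absent. The added check in isolation, with a simplified dict-based schema standing in for the real output_config objects:

def validate_outputs(result: dict, output_schema: dict, prefix: str = "") -> None:
    dot = '.' if prefix else ''
    for output_name in output_schema:
        if output_name not in result:
            raise ValueError(f'Output {prefix}{dot}{output_name} is missing.')

try:
    validate_outputs({"answer": 1}, {"answer": {}, "score": {}})
except ValueError as e:
    print(e)  # Output score is missing.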

View File

@@ -1,28 +1,21 @@
import threading
from typing import Any, cast
from flask import Flask, current_app
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalService
from core.rerank.rerank import RerankRunner
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.workflow.nodes.knowledge_retrieval.multi_dataset_react_route import ReactMultiDatasetRouter
from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
from models.dataset import Dataset, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus
default_retrieval_model = {
@@ -106,10 +99,45 @@ class KnowledgeRetrievalNode(BaseNode):
available_datasets.append(dataset)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
all_documents = self._single_retrieve(available_datasets, node_data, query)
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data)
# check whether the model supports tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model_config.model,
credentials=model_config.credentials
)
if model_schema:
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
all_documents = dataset_retrieval.single_retrieve(
available_datasets=available_datasets,
tenant_id=self.tenant_id,
user_id=self.user_id,
app_id=self.app_id,
user_from=self.user_from.value,
query=query,
model_config=model_config,
model_instance=model_instance,
planning_strategy=planning_strategy
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
all_documents = self._multiple_retrieve(available_datasets, node_data, query)
all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id,
self.user_from.value,
available_datasets, query,
node_data.multiple_retrieval_config.top_k,
node_data.multiple_retrieval_config.score_threshold,
node_data.multiple_retrieval_config.reranking_model.provider,
node_data.multiple_retrieval_config.reranking_model.model)
context_list = []
if all_documents:
@@ -184,84 +212,6 @@ class KnowledgeRetrievalNode(BaseNode):
variable_mapping['query'] = node_data.query_variable_selector
return variable_mapping
def _single_retrieve(self, available_datasets, node_data, query):
tools = []
for dataset in available_datasets:
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
description = description.replace('\n', '').replace('\r', '')
message_tool = PromptMessageTool(
name=dataset.id,
description=description,
parameters={
"type": "object",
"properties": {},
"required": [],
}
)
tools.append(message_tool)
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data)
# check whether the model supports tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model_config.model,
credentials=model_config.credentials
)
if not model_schema:
return None
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
dataset_id = None
if planning_strategy == PlanningStrategy.REACT_ROUTER:
react_multi_dataset_router = ReactMultiDatasetRouter()
dataset_id = react_multi_dataset_router.invoke(query, tools, node_data, model_config, model_instance,
self.user_id, self.tenant_id)
elif planning_strategy == PlanningStrategy.ROUTER:
function_call_router = FunctionCallMultiDatasetRouter()
dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
if dataset_id:
# get retrieval model config
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if dataset:
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
# get retrieval method
retrival_method = retrieval_model_config['search_method']
# get reranking model
reranking_model=retrieval_model_config['reranking_model'] \
if retrieval_model_config['reranking_enable'] else None
# get score threshold
score_threshold = .0
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
query=query,
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model)
self._on_query(query, [dataset_id])
if results:
self._on_retrival_end(results)
return results
return []
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
ModelInstance, ModelConfigWithCredentialsEntity]:
"""
@@ -332,112 +282,3 @@ class KnowledgeRetrievalNode(BaseNode):
parameters=completion_params,
stop=stop,
)
def _multiple_retrieve(self, available_datasets, node_data, query):
threads = []
all_documents = []
dataset_ids = [dataset.id for dataset in available_datasets]
for dataset in available_datasets:
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset.id,
'query': query,
'top_k': node_data.multiple_retrieval_config.top_k,
'all_documents': all_documents,
})
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
# do rerank for searched documents
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=node_data.multiple_retrieval_config.reranking_model.provider,
model_type=ModelType.RERANK,
model=node_data.multiple_retrieval_config.reranking_model.model
)
rerank_runner = RerankRunner(rerank_model_instance)
all_documents = rerank_runner.run(query, all_documents,
node_data.multiple_retrieval_config.score_threshold,
node_data.multiple_retrieval_config.top_k)
self._on_query(query, dataset_ids)
if all_documents:
self._on_retrival_end(all_documents)
return all_documents
def _on_retrival_end(self, documents: list[Document]) -> None:
"""Handle retrival end."""
for document in documents:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata['doc_id']
)
if 'dataset_id' in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False
)
db.session.commit()
def _on_query(self, query: str, dataset_ids: list[str]) -> None:
"""
Handle query.
"""
if not query:
return
for dataset_id in dataset_ids:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source='app',
source_app_id=self.app_id,
created_by_role=self.user_from.value,
created_by=self.user_id
)
db.session.add(dataset_query)
db.session.commit()
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
return []
# get the retrieval model; if it is not set, use the default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(retrival_method='keyword_search',
dataset_id=dataset.id,
query=query,
top_k=top_k
)
if documents:
all_documents.extend(documents)
else:
if top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=retrieval_model['score_threshold']
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model['reranking_model']
if retrieval_model['reranking_enable'] else None
)
all_documents.extend(documents)

View File

@@ -10,7 +10,7 @@ from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageContentType
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -434,6 +434,22 @@ class LLMNode(BaseNode):
)
stop = model_config.stop
vision_enabled = node_data.vision.enabled
for prompt_message in prompt_messages:
if not isinstance(prompt_message.content, str):
prompt_message_content = []
for content_item in prompt_message.content:
if vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
prompt_message_content.append(content_item)
elif content_item.type == PromptMessageContentType.TEXT:
prompt_message_content.append(content_item)
if len(prompt_message_content) > 1:
prompt_message.content = prompt_message_content
elif (len(prompt_message_content) == 1
and prompt_message_content[0].type == PromptMessageContentType.TEXT):
prompt_message.content = prompt_message_content[0].data
return prompt_messages, stop
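The new loop above strips image parts from multimodal prompt messages when vision is disabled, and collapses a lone text part back into a plain string. The same logic on plain dicts, as an illustrative stand-in for the PromptMessageContent objects:

def filter_content(content, vision_enabled: bool):
    if isinstance(content, str):
        return content  # already plain text
    kept = [item for item in content
            if item["type"] == "text" or (vision_enabled and item["type"] == "image")]
    if len(kept) == 1 and kept[0]["type"] == "text":
        return kept[0]["data"]  # collapse to a plain string
    return kept

mixed = [{"type": "text", "data": "describe this"}, {"type": "image", "data": "<bytes>"}]
print(filter_content(mixed, vision_enabled=False))  # 'describe this'
print(filter_content(mixed, vision_enabled=True))   # both items kept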
@classmethod

View File

@@ -13,7 +13,7 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.llm.llm_node import LLMNode
 from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
@@ -65,7 +65,9 @@ class QuestionClassifierNode(LLMNode):
         categories = [_class.name for _class in node_data.classes]
         try:
             result_text_json = json.loads(result_text.strip('```JSON\n'))
-            categories = result_text_json.get('categories', [])
+            categories_result = result_text_json.get('categories', [])
+            if categories_result:
+                categories = categories_result
         except Exception:
             logging.error(f"Failed to parse result text: {result_text}")

         try:
@@ -89,14 +91,24 @@
                 inputs=variables,
                 process_data=process_data,
                 outputs=outputs,
-                edge_source_handle=classes_map.get(categories[0], None)
+                edge_source_handle=classes_map.get(categories[0], None),
+                metadata={
+                    NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+                    NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
+                    NodeRunMetadataKey.CURRENCY: usage.currency
+                }
             )
         except ValueError as e:
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 inputs=variables,
-                error=str(e)
+                error=str(e),
+                metadata={
+                    NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+                    NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
+                    NodeRunMetadataKey.CURRENCY: usage.currency
+                }
             )

     @classmethod
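
(Sketch, not part of the diff.) The parsing change above means a well-formed answer with an empty 'categories' list no longer wipes out the defaults derived from node_data.classes. A runnable restatement, with default_categories standing in for that list:

import json
import logging


def parse_categories(result_text: str, default_categories: list[str]) -> list[str]:
    categories = list(default_categories)  # defaults from node_data.classes
    try:
        result_text_json = json.loads(result_text.strip('```JSON\n'))
        categories_result = result_text_json.get('categories', [])
        if categories_result:  # an empty list no longer clobbers the defaults
            categories = categories_result
    except Exception:
        logging.error(f"Failed to parse result text: {result_text}")
    return categories


assert parse_categories('{"categories": []}', ['a', 'b']) == ['a', 'b']
assert parse_categories('{"categories": ["b"]}', ['a', 'b']) == ['b']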

View File

@@ -1,9 +1,10 @@
 import logging
 import time
-from typing import Optional
+from typing import Optional, cast

+from core.app.app_config.entities import FileExtraConfig
 from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
-from core.file.file_obj import FileVar
+from core.file.file_obj import FileTransferMethod, FileType, FileVar
 from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool, VariableValue
@@ -16,6 +17,7 @@ from core.workflow.nodes.end.end_node import EndNode
 from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
 from core.workflow.nodes.if_else.if_else_node import IfElseNode
 from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.llm.entities import LLMNodeData
 from core.workflow.nodes.llm.llm_node import LLMNode
 from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
 from core.workflow.nodes.start.start_node import StartNode
@@ -219,7 +221,8 @@ class WorkflowEngineManager:
             raise ValueError('node id not found in workflow graph')

         # Get node class
-        node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type')))
+        node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
+        node_cls = node_classes.get(node_type)

         # init workflow run state
         node_instance = node_cls(
@@ -252,11 +255,40 @@ class WorkflowEngineManager:
             variable_node_id = variable_selector[0]
             variable_key_list = variable_selector[1:]

+            # get value
+            value = user_inputs.get(variable_key)
+
+            # temp fix for image type
+            if node_type == NodeType.LLM:
+                new_value = []
+                if isinstance(value, list):
+                    node_data = node_instance.node_data
+                    node_data = cast(LLMNodeData, node_data)
+
+                    detail = node_data.vision.configs.detail if node_data.vision.configs else None
+
+                    for item in value:
+                        if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
+                            transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
+                            file = FileVar(
+                                tenant_id=workflow.tenant_id,
+                                type=FileType.IMAGE,
+                                transfer_method=transfer_method,
+                                url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
+                                related_id=item.get(
+                                    'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
+                                extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
+                            )
+                            new_value.append(file)
+
+                    if new_value:
+                        value = new_value
+
             # append variable and value to variable pool
             variable_pool.append_variable(
                 node_id=variable_node_id,
                 variable_key_list=variable_key_list,
-                value=user_inputs.get(variable_key)
+                value=value
             )

         # run node
         node_run_result = node_instance.run(
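
(Sketch, not part of the diff.) For reference, the input shape this temp fix expects: a list of dicts carrying type, transfer_method, and either url or upload_file_id. The surrounding names below (the 'files' key, the example values) are hypothetical; only the keys read by the code above are grounded in the diff:

# Hypothetical single-node debug input for an LLM node with vision enabled.
user_inputs = {
    'query': 'What is in this picture?',
    'files': [
        {
            'type': 'image',
            'transfer_method': 'remote_url',   # rewrapped as FileVar with url set
            'url': 'https://example.com/cat.png',
        },
        {
            'type': 'image',
            'transfer_method': 'local_file',   # rewrapped as FileVar with related_id set
            'upload_file_id': 'a-previously-uploaded-file-id',
        },
    ],
}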

View File

@@ -28,6 +28,7 @@ def init_app(app: Flask) -> Celery:
     celery_app.conf.update(
         result_backend=app.config["CELERY_RESULT_BACKEND"],
+        broker_connection_retry_on_startup=True,
     )

     if app.config["BROKER_USE_SSL"]:
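
(Note, not part of the diff.) Per the Celery 5.3 release notes, broker_connection_retry will stop applying to the first startup connection in Celery 6, and leaving the new flag unset triggers a deprecation warning. Setting it explicitly preserves today's retry-on-first-connect behavior; reduced to its essentials:

from celery import Celery

celery_app = Celery(__name__)
celery_app.conf.update(
    broker_connection_retry_on_startup=True,  # keep retrying the first broker connect, silence the warning
)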

View File

@@ -815,7 +815,7 @@ class Message(db.Model):
     @property
     def workflow_run(self):
         if self.workflow_run_id:
-            from api.models.workflow import WorkflowRun
+            from .workflow import WorkflowRun
             return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first()

         return None
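
(Note, not part of the diff.) Reasoning inferred from the paths: 'api.models.workflow' only resolves when the project root is importable as a package named 'api', while the relative form resolves against the current 'models' package however the service is mounted. Keeping the import inside the property body also defers it, dodging a circular import between the two model modules at load time. A common companion pattern, shown as a hedged sketch, exposes the name to type checkers without the runtime import:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # visible to mypy/IDEs only; at runtime the import stays inside the property body
    from .workflow import WorkflowRun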

View File

@@ -299,6 +299,10 @@
             Message.workflow_run_id == self.id
         ).first()

+    @property
+    def workflow(self):
+        return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
+

 class WorkflowNodeExecutionTriggeredFrom(Enum):
     """

View File

@@ -34,20 +34,27 @@ redis[hiredis]~=5.0.3
 openpyxl==3.1.2
 chardet~=5.1.0
 python-docx~=1.1.0
-pypdfium2==4.16.0
+pypdfium2~=4.17.0
 resend~=0.7.0
 pyjwt~=2.8.0
 anthropic~=0.23.1
 newspaper3k==0.2.8
-google-api-python-client==2.90.0
 wikipedia==1.4.0
 readabilipy==0.2.0
+google-ai-generativelanguage==0.6.1
+google-api-core==2.18.0
+google-api-python-client==2.90.0
+google-auth==2.29.0
+google-auth-httplib2==0.2.0
+google-generativeai==0.5.0
 google-search-results==2.4.2
+googleapis-common-protos==1.63.0
 replicate~=0.22.0
 websocket-client~=1.7.0
 dashscope[tokenizer]~=1.14.0
 huggingface_hub~=0.16.4
-transformers~=4.31.0
+transformers~=4.35.0
+tokenizers~=0.15.0
 pandas==1.5.3
 xinference-client==0.9.4
 safetensors==0.3.2
@@ -55,13 +62,12 @@ zhipuai==1.0.7
 werkzeug~=3.0.1
 pymilvus==2.3.0
 qdrant-client==1.7.3
-cohere~=4.44
+cohere~=5.2.4
 pyyaml~=6.0.1
 numpy~=1.25.2
 unstructured[docx,pptx,msg,md,ppt]~=0.10.27
 bs4~=0.0.1
 markdown~=3.5.1
-google-generativeai~=0.3.2
 httpx[socks]~=0.24.1
 matplotlib~=3.8.2
 yfinance~=0.2.35
@@ -75,4 +81,4 @@ twilio==9.0.0
 qrcode~=7.4.2
 azure-storage-blob==12.9.0
 azure-identity==1.15.0
-lxml==5.1.0
\ No newline at end of file
+lxml==5.1.0
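
(Note, not part of the diff.) Several pins above move from exact '==' to compatible-release '~='. Illustrating the semantics of the new pypdfium2 pin with the 'packaging' library (an assumed extra, used here purely for demonstration):

from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=4.17.0")  # equivalent to: >=4.17.0, ==4.17.*
assert "4.17.0" in spec
assert "4.17.9" in spec   # patch upgrades are allowed
assert "4.18.0" not in spec  # minor upgrades are not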

View File

@@ -221,7 +221,8 @@ class AppService:
             "name": app.name,
             "mode": app.mode,
             "icon": app.icon,
-            "icon_background": app.icon_background
+            "icon_background": app.icon_background,
+            "description": app.description
         }
     }

View File

@@ -1,5 +1,8 @@
 from typing import Optional, Union

+from sqlalchemy import or_
+
+from core.app.entities.app_invoke_entities import InvokeFrom
 from core.llm_generator.llm_generator import LLMGenerator
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
@@ -13,8 +16,9 @@ class ConversationService:
     @classmethod
     def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
                               last_id: Optional[str], limit: int,
-                              include_ids: Optional[list] = None, exclude_ids: Optional[list] = None,
-                              exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
+                              invoke_from: InvokeFrom,
+                              include_ids: Optional[list] = None,
+                              exclude_ids: Optional[list] = None) -> InfiniteScrollPagination:
         if not user:
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
@@ -24,6 +28,7 @@
             Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
             Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
             Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
+            or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value)
         )

         if include_ids is not None:
@@ -32,9 +37,6 @@
         if exclude_ids is not None:
             base_query = base_query.filter(~Conversation.id.in_(exclude_ids))

-        if exclude_debug_conversation:
-            base_query = base_query.filter(Conversation.override_model_configs == None)
-
         if last_id:
             last_conversation = base_query.filter(
                 Conversation.id == last_id,
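
(Sketch, not part of the diff.) Call sites change accordingly: the exclude_debug_conversation flag is gone, and callers pass the invocation source, which the query matches against Conversation.invoke_from while treating legacy NULL rows as a match. A hedged usage sketch, assuming InvokeFrom.WEB_APP is the member used by the web-app endpoint, with app_model, end_user, and args as placeholders from the calling controller:

pagination = ConversationService.pagination_by_last_id(
    app_model=app_model,
    user=end_user,
    last_id=args.get('last_id'),
    limit=args.get('limit', 20),
    invoke_from=InvokeFrom.WEB_APP,  # assumed member; debugger traffic would pass its own source
)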

Some files were not shown because too many files have changed in this diff.