mirror of https://github.com/langgenius/dify.git
synced 2026-01-20 05:54:02 +00:00

Compare commits: 82 commits

6fa0e4072d, e15d18aa1c, 164ef26a60, 0dada847ef, 36b7dbb8d0, 02e483c99b, afe30e15a0, 9a1ea9ac03, 693647a141, cea107b165, 509c640a80, 617e7cee81, d87d4b9b56, c889717d24, 1f302990c6, 37024afe9c, 18b855140d, 7c520b52c1, b98e363a5c, 0a7ea9d206, 3d473b9763, e0df7505f6, 43bb0b0b93, 6164604462, 826c422ac4, bf63a43bda, 55fc46c707, 5102430a68, 0f897bc1f9, d948b0b49b, b6de97ad53, 8cefa6b82e, 81e1b3fc61, 4c1cfd9278, 14bb0b02ac, 240c793e7a, 89a853212b, 97d1e0bbbb, cfb5ccc7d3, 835e547195, af9ccb7072, 74de7cf33c, f5e65b98a9, eb76d7a226, e635f3dc1d, 2a6b7d57cb, 8bb225bec6, 39d3fc4742, 5c98260cec, f599f41336, 1384a6d0fd, 28089c98c1, 10d6d50b6c, 752f6fb15a, 8fcf459285, 9c01bcb3e5, a2c068d949, 4ad3f2cdc2, 29918c498c, 269432a5e6, a33b774314, cc5ccaaca1, 33ea689861, 581836b716, 0516b78d6f, 84d7cbf916, f514fd2182, 86707928d4, 3c3fb3cd3f, 337899a03d, bae0c071cd, 19cb3c7871, 2e4dec365d, ca3e2e6cc0, 283979fc46, a81c1ab6ae, 48d4d55ecc, b7691f5658, 1382f10433, d8db728c33, d2259f20cb, 9720d6b7a5
.github/workflows/build-push.yml (vendored, 2 changed lines)
@@ -46,7 +46,7 @@ jobs:
      with:
        images: ${{ env[matrix.image_name_env] }}
        tags: |
-         type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' && startsWith(github.ref, 'refs/tags/') }}
+         type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
          type=ref,event=branch
          type=sha,enable=true,priority=100,prefix=,suffix=,format=long
          type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}

The old `enable` expression required `github.ref` to equal `refs/heads/main` and simultaneously start with `refs/tags/`, which can never both be true, so the `latest` tag was never applied; the fix enables `latest` on any tag push.
CONTRIBUTING.md

@@ -36,7 +36,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
 | Feature Type | Priority |
 | ------------------------------------------------------------ | --------------- |
 | High-Priority Features as being labeled by a team member | High Priority |
-| Popular feature requests from our [community feedback board](https://feedback.dify.ai/) | Medium Priority |
+| Popular feature requests from our [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Medium Priority |
 | Non-core features and minor enhancements | Low Priority |
 | Valuable but not immediate | Future-Feature |
The same feedback-board URL change appears in a second contributing-guide file:

@@ -34,7 +34,7 @@
 | Feature Type | Priority |
 | ------------------------------------------------------------ | --------------- |
 | High-Priority Features as being labeled by a team member | High Priority |
-| Popular feature requests from our [community feedback board](https://feedback.dify.ai/) | Medium Priority |
+| Popular feature requests from our [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Medium Priority |
 | Non-core features and minor enhancements | Low Priority |
 | Valuable but not immediate | Future-Feature |
README.md (271 changed lines)
@@ -1,95 +1,176 @@
-[![](./images/describe-en.png)](https://dify.ai)
+![](./images/GitHub_README_cover.png)

<p align="center">
-  <a href="./README.md">English</a> |
-  <a href="./README_CN.md">简体中文</a> |
-  <a href="./README_JA.md">日本語</a> |
-  <a href="./README_ES.md">Español</a> |
-  <a href="./README_KL.md">Klingon</a> |
-  <a href="./README_FR.md">Français</a>
+  <a href="https://cloud.dify.ai">Dify Cloud</a> ·
+  <a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·
+  <a href="https://docs.dify.ai">Documentation</a> ·
+  <a href="https://cal.com/guchenhe/30min">Commercial inquiry</a>
</p>

<p align="center">
  <a href="https://dify.ai" target="_blank">
    <img alt="Static Badge" src="https://img.shields.io/badge/AI-Dify?logo=AI&logoColor=%20%23f5f5f5&label=Dify&labelColor=%20%23155EEF&color=%23EAECF0"></a>
  <img alt="Static Badge" src="https://img.shields.io/badge/Product-F04438"></a>
  <a href="https://dify.ai/pricing" target="_blank">
    <img alt="Static Badge" src="https://img.shields.io/badge/free-pricing?logo=free&color=%20%23155EEF&label=pricing&labelColor=%20%23528bff"></a>
  <a href="https://discord.gg/FngNHpbcY7" target="_blank">
-    <img src="https://img.shields.io/discord/1082486657678311454?logo=discord"
+    <img src="https://img.shields.io/discord/1082486657678311454?logo=discord&labelColor=%20%235462eb&logoColor=%20%23f5f5f5&color=%20%235462eb"
        alt="chat on Discord"></a>
  <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
-    <img src="https://img.shields.io/twitter/follow/dify_ai?style=social&logo=X"
+    <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
        alt="follow on Twitter"></a>
  <a href="https://hub.docker.com/u/langgenius" target="_blank">
-    <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
+    <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
+  <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
+    <img alt="Commits last month" src="https://img.shields.io/github/commit-activity/m/langgenius/dify?labelColor=%20%2332b583&color=%20%2312b76a"></a>
+  <a href="https://github.com/langgenius/dify/" target="_blank">
+    <img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
+  <a href="https://github.com/langgenius/dify/discussions/" target="_blank">
+    <img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
</p>

<p align="center">
-  <a href="https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6" target="_blank">
-  📌 Check out Dify Premium on AWS and deploy it to your own AWS VPC with one-click.
-  </a>
+  <a href="./README.md"><img alt="Commits last month" src="https://img.shields.io/badge/English-d9d9d9"></a>
+  <a href="./README_CN.md"><img alt="Commits last month" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
+  <a href="./README_JA.md"><img alt="Commits last month" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
+  <a href="./README_ES.md"><img alt="Commits last month" src="https://img.shields.io/badge/Español-d9d9d9"></a>
+  <a href="./README_KL.md"><img alt="Commits last month" src="https://img.shields.io/badge/Français-d9d9d9"></a>
+  <a href="./README_FR.md"><img alt="Commits last month" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
</p>

-**Dify** is an open-source LLM app development platform. Dify's intuitive interface combines a RAG pipeline, AI workflow orchestration, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production.
+#

![](./images/describe.png)

+<p align="center">
+  <a href="https://trendshift.io/repositories/2152" target="_blank"><img src="https://trendshift.io/api/badge/repositories/2152" alt="langgenius%2Fdify | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
+</p>
+Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:
+</br> </br>

+**1. Workflow**:
+  Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond.

+  https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa

-## Using our Cloud Services
+**2. Comprehensive model support**:
+  Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama2, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).

-You can try out [Dify.AI Cloud](https://dify.ai) now. It provides all the capabilities of the self-deployed version, and includes 200 free requests to OpenAI GPT-3.5.
+![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3)

-### Looking to purchase via AWS?
-Check out [Dify Premium on AWS](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click.

-## Dify vs. LangChain vs. Assistants API
+**3. Prompt IDE**:
+  Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app.

-| Feature | Dify.AI | Assistants API | LangChain |
-|---------|---------|----------------|-----------|
-| **Programming Approach** | API-oriented | API-oriented | Python Code-oriented |
-| **Ecosystem Strategy** | Open Source | Close Source | Open Source |
-| **RAG Engine** | Supported | Supported | Not Supported |
-| **Prompt IDE** | Included | Included | None |
-| **Supported LLMs** | Rich Variety | OpenAI-only | Rich Variety |
-| **Local Deployment** | Supported | Not Supported | Not Applicable |
+**4. RAG Pipeline**:
+  Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats.

+**5. Agent capabilities**:
+  You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha.

+**6. LLMOps**:
+  Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations.

+**7. Backend-as-a-Service**:
+  All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.

+## Feature Comparison
+<table style="width: 100%;">
+  <tr>
+    <th align="center">Feature</th>
+    <th align="center">Dify.AI</th>
+    <th align="center">LangChain</th>
+    <th align="center">Flowise</th>
+    <th align="center">OpenAI Assistants API</th>
+  </tr>
+  <tr>
+    <td align="center">Programming Approach</td>
+    <td align="center">API + App-oriented</td>
+    <td align="center">Python Code</td>
+    <td align="center">App-oriented</td>
+    <td align="center">API-oriented</td>
+  </tr>
+  <tr>
+    <td align="center">Supported LLMs</td>
+    <td align="center">Rich Variety</td>
+    <td align="center">Rich Variety</td>
+    <td align="center">Rich Variety</td>
+    <td align="center">OpenAI-only</td>
+  </tr>
+  <tr>
+    <td align="center">RAG Engine</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+  </tr>
+  <tr>
+    <td align="center">Agent</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+  </tr>
+  <tr>
+    <td align="center">Workflow</td>
+    <td align="center">✅</td>
+    <td align="center">❌</td>
+    <td align="center">✅</td>
+    <td align="center">❌</td>
+  </tr>
+  <tr>
+    <td align="center">Observability</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">❌</td>
+    <td align="center">❌</td>
+  </tr>
+  <tr>
+    <td align="center">Enterprise Feature (SSO/Access control)</td>
+    <td align="center">✅</td>
+    <td align="center">❌</td>
+    <td align="center">❌</td>
+    <td align="center">❌</td>
+  </tr>
+  <tr>
+    <td align="center">Local Deployment</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">❌</td>
+  </tr>
+</table>

+## Using Dify

+- **Cloud </br>**
+  We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.

+- **Self-hosting Dify Community Edition</br>**
+  Quickly get Dify running in your environment with this [starter guide](#quick-start).
+  Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.

+- **Dify for Enterprise / Organizations</br>**
+  We provide additional enterprise-centric features. [Schedule a meeting with us](https://cal.com/guchenhe/30min) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs. </br>
+  > For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. It's an affordable AMI offering with the option to create apps with custom logo and branding.

+## Staying ahead

+Star Dify on GitHub and be instantly notified of new releases.

+![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f)

-## Features
+## Quick Start
+> Before installing Dify, make sure your machine meets the following minimum system requirements:
+>
+>- CPU >= 2 Core
+>- RAM >= 4GB

-![](./images/models.png)

-**1. LLM Support**: Integration with OpenAI's GPT family of models, or the open-source Llama2 family models. In fact, Dify supports mainstream commercial models and open-source models (locally deployed or based on MaaS).
-**2. Prompt IDE**: Visual orchestration of applications and services based on LLMs with your team.
-**3. RAG Engine**: Includes various RAG capabilities based on full-text indexing or vector database embeddings, allowing direct upload of PDFs, TXTs, and other text formats.
-**4. AI Agent**: Based on Function Calling and ReAct, the Agent inference framework allows users to customize tools, what you see is what you get. Dify provides more than a dozen built-in tool calling capabilities, such as Google Search, DALL·E, Stable Diffusion, WolframAlpha, etc.
-**5. Continuous Operations**: Monitor and analyze application logs and performance, continuously improving Prompts, datasets, or models using production data.

-## Before You Start

-**Star us on GitHub, and be instantly notified for new releases!**

-![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f)

-- [Website](https://dify.ai)
-- [Docs](https://docs.dify.ai)
-- [Deployment Docs](https://docs.dify.ai/getting-started/install-self-hosted)
-- [FAQ](https://docs.dify.ai/getting-started/faq)

-## Install the Community Edition

-### System Requirements

-Before installing Dify, make sure your machine meets the following minimum system requirements:

-- CPU >= 2 Core
-- RAM >= 4GB

-### Quick Start
+</br>

The easiest way to start the Dify server is to run our [docker-compose.yml](docker/docker-compose.yaml) file. Before running the installation command, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
@@ -98,61 +179,63 @@ cd docker
docker compose up -d
```

-After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization installation process.
+After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.

-#### Deploy with Helm Chart
+> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)

-[Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
+## Next steps

+If you need to customize the configuration, please refer to the comments in our [docker-compose.yml](docker/docker-compose.yaml) file and manually set the environment configuration. After making the changes, please run `docker-compose up -d` again. You can see the full list of environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).

+If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) which allow Dify to be deployed on Kubernetes.

+- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
+- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)

-### Configuration

-If you need to customize the configuration, please refer to the comments in our [docker-compose.yml](docker/docker-compose.yaml) file and manually set the environment configuration. After making the changes, please run `docker-compose up -d` again. You can see the full list of environment variables in our [docs](https://docs.dify.ai/getting-started/install-self-hosted/environments).

-## Star History

-[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)

## Contributing

For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
At the same time, please consider supporting Dify by sharing it on social media and at events and conferences.

-### Projects made by community

-- [Chatbot Chrome Extension by @charli117](https://github.com/langgenius/chatbot-chrome-extension)
+> We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).

-### Contributors
+**Contributors**

<a href="https://github.com/langgenius/dify/graphs/contributors">
  <img src="https://contrib.rocks/image?repo=langgenius/dify" />
</a>

-### Translations
+## Community & Contact

-We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).

-## Community & Support

-* [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and checking out our feature roadmap.
+* [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions.
* [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
-* [Email Support](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify). Best for: questions you have about using Dify.AI.
+* [Email](mailto:support@dify.ai?subject=[GitHub]Questions%20About%20Dify). Best for: questions you have about using Dify.AI.
* [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
* [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
* [Business Contact](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry). Best for: business inquiries of licensing Dify.AI for commercial use.

### Direct Meetings
Or, schedule a meeting directly with a team member:

**Help us make Dify better. Reach out directly to us**.
<table>
  <tr>
    <th>Point of Contact</th>
    <th>Purpose</th>
  </tr>
  <tr>
    <td><a href='https://cal.com/guchenhe/15min' target='_blank'><img class="schedule-button" src='https://github.com/langgenius/dify/assets/13230914/9ebcd111-1205-4d71-83d5-948d70b809f5' alt='Git-Hub-README-Button-3x' style="width: 180px; height: auto; object-fit: contain;"/></a></td>
    <td>Business enquiries & product feedback</td>
  </tr>
  <tr>
    <td><a href='https://cal.com/pinkbanana' target='_blank'><img class="schedule-button" src='https://github.com/langgenius/dify/assets/13230914/d1edd00a-d7e4-4513-be6c-e57038e143fd' alt='Git-Hub-README-Button-2x' style="width: 180px; height: auto; object-fit: contain;"/></a></td>
    <td>Contributions, issues & feature requests</td>
  </tr>
</table>

+## Star History

+[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)

| Point of Contact | Purpose |
| :----------------------------------------------------------: | :----------------------------------------------------------: |
| <a href='https://cal.com/guchenhe/15min' target='_blank'><img src='https://i.postimg.cc/fWBqSmjP/Git-Hub-README-Button-3x.png' border='0' alt='Git-Hub-README-Button-3x' height="60" width="214"/></a> | Product design feedback, user experience discussions, feature planning and roadmaps. |
| <a href='https://cal.com/pinkbanana' target='_blank'><img src='https://i.postimg.cc/LsRTh87D/Git-Hub-README-Button-2x.png' border='0' alt='Git-Hub-README-Button-2x' height="60" width="225"/></a> | Technical support, issues, or feature requests |

## Security Disclosure
README_CN.md

@@ -21,6 +21,10 @@
    <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

+<p align="center">
+  <a href="https://trendshift.io/repositories/2152" target="_blank"><img src="https://trendshift.io/api/badge/repositories/2152" alt="langgenius%2Fdify | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
+</p>

<p align="center">
  <a href="https://mp.weixin.qq.com/s/TnyfIuH-tPi9o1KNjwVArw" target="_blank">
  Dify releases AI Agent capabilities: build GPTs and Assistants on top of different large language models
README_JA.md (Japanese in the original; translated below)

@@ -115,6 +115,7 @@ docker compose up -d

We welcome contributions to Dify: submitting code, reporting issues, offering new ideas, or sharing the interesting and useful AI applications you have built on Dify all help make it better. We also welcome sharing Dify at events, conferences, and on social media.

+- [Github Discussion](https://github.com/langgenius/dify/discussions). 👉: Share your apps and communicate with the community.
- [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs and errors you encounter while using Dify.AI; see the [Contribution Guide](CONTRIBUTING.md).
- [Email Support](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify). Best for: questions about using Dify.AI.
- [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
api/Dockerfile

@@ -11,7 +11,8 @@ RUN apt-get update \

COPY requirements.txt /requirements.txt

-RUN pip install --prefix=/pkg -r requirements.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install --prefix=/pkg -r requirements.txt

# production stage
FROM base AS production

The BuildKit cache mount keeps pip's download cache on the build host across builds, so rebuilds skip re-downloading unchanged packages without baking the cache into the image layer.
api/README.md

@@ -17,16 +17,16 @@
   ```bash
   sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
   ```
-3.5 If you use Anaconda, create a new environment and activate it
+4. If you use Anaconda, create a new environment and activate it
   ```bash
   conda create --name dify python=3.10
   conda activate dify
   ```
-4. Install dependencies
+5. Install dependencies
   ```bash
   pip install -r requirements.txt
   ```
-5. Run migrate
+6. Run migrate

   Before the first launch, migrate the database to the latest version.

@@ -47,9 +47,11 @@
   ```bash
   pip install -r requirements.txt --upgrade --force-reinstall
   ```

-6. Start backend:
+7. Start backend:
   ```bash
   flask run --host 0.0.0.0 --port=5001 --debug
   ```
-7. Setup your application by visiting http://localhost:5001/console/api/setup or other apis...
-8. If you need to debug local async processing, you can run `celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`, celery can do dataset importing and other async tasks.
+8. Setup your application by visiting http://localhost:5001/console/api/setup or other apis...
+9. If you need to debug local async processing, please start the worker service by running
+   `celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`.
+   The started celery app handles the async tasks, e.g. dataset importing and documents indexing.
api/config.py

@@ -42,7 +42,7 @@ DEFAULTS = {
    'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
    'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
    'HOSTED_OPENAI_PAID_ENABLED': 'False',
-   'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
+   'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-turbo-2024-04-09,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
    'HOSTED_AZURE_OPENAI_ENABLED': 'False',
    'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,

@@ -67,6 +67,7 @@ DEFAULTS = {
    'CODE_EXECUTION_ENDPOINT': '',
    'CODE_EXECUTION_API_KEY': '',
    'TOOL_ICON_CACHE_MAX_AGE': 3600,
+   'MILVUS_DATABASE': 'default',
    'KEYWORD_DATA_SOURCE_TYPE': 'database',
}

@@ -98,7 +99,7 @@ class Config:
        # ------------------------
        # General Configurations.
        # ------------------------
-       self.CURRENT_VERSION = "0.6.0"
+       self.CURRENT_VERSION = "0.6.2"
        self.COMMIT_SHA = get_env('COMMIT_SHA')
        self.EDITION = "SELF_HOSTED"
        self.DEPLOY_ENV = get_env('DEPLOY_ENV')

@@ -212,6 +213,7 @@ class Config:
        self.MILVUS_USER = get_env('MILVUS_USER')
        self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
        self.MILVUS_SECURE = get_env('MILVUS_SECURE')
+       self.MILVUS_DATABASE = get_env('MILVUS_DATABASE')

        # weaviate settings
        self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
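For context on how these defaults take effect: entries in `DEFAULTS` act as fallbacks when the corresponding environment variable is unset. The `get_env` helper itself is not shown in this diff, so the sketch below is an assumption about its shape rather than the actual implementation:

```python
import os

# Assumed fallback table, mirroring the DEFAULTS entries in the hunk above.
DEFAULTS = {
    'MILVUS_DATABASE': 'default',
    'KEYWORD_DATA_SOURCE_TYPE': 'database',
}

def get_env(key: str):
    """Return the environment value for `key`, falling back to DEFAULTS."""
    return os.environ.get(key, DEFAULTS.get(key))

# With MILVUS_DATABASE unset in the environment, the Config line
# `self.MILVUS_DATABASE = get_env('MILVUS_DATABASE')` would yield 'default'.
```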
api/controllers/console/explore/conversation.py

@@ -6,6 +6,7 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
+from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode

@@ -39,8 +40,8 @@ class ConversationListApi(InstalledAppResource):
                user=current_user,
                last_id=args['last_id'],
                limit=args['limit'],
+               invoke_from=InvokeFrom.EXPLORE,
                pinned=pinned,
-               exclude_debug_conversation=True
            )
        except LastConversationNotExistsError:
            raise NotFound("Last Conversation Not Exists.")
api/controllers/service_api/app/app.py

@@ -1,14 +1,11 @@
import json

-from flask import current_app
-from flask_restful import fields, marshal_with, Resource
+from flask_restful import Resource, fields, marshal_with

from controllers.service_api import api
-from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
-from extensions.ext_database import db
-from models.model import App, AppModelConfig, AppMode
-from models.tools import ApiToolProvider
+from models.model import App, AppMode
from services.app_service import AppService

@@ -92,6 +89,16 @@ class AppMetaApi(Resource):
        """Get app meta"""
        return AppService().get_app_meta(app_model)

+class AppInfoApi(Resource):
+    @validate_app_token
+    def get(self, app_model: App):
+        """Get app information"""
+        return {
+            'name': app_model.name,
+            'description': app_model.description
+        }

api.add_resource(AppParameterApi, '/parameters')
api.add_resource(AppMetaApi, '/meta')
+api.add_resource(AppInfoApi, '/info')
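Since `AppInfoApi` is registered at `/info` behind `validate_app_token`, a client holding an app's service API key can fetch the app's name and description. A hedged usage sketch: the base URL and key below are placeholders, and the `/v1` prefix is an assumption about where the service API is mounted:

```python
import requests

API_BASE = "http://localhost/v1"   # placeholder: your Dify service API base URL
APP_KEY = "app-xxxxxxxxxxxx"       # placeholder: an app's service API key

# validate_app_token reads the key from the Authorization header.
resp = requests.get(
    f"{API_BASE}/info",
    headers={"Authorization": f"Bearer {APP_KEY}"},
    timeout=10,
)
resp.raise_for_status()
print(resp.json())  # {'name': '...', 'description': '...'}
```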
api/controllers/service_api/app/conversation.py

@@ -6,6 +6,7 @@ import services
from controllers.service_api import api
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
+from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser

@@ -27,7 +28,13 @@ class ConversationApi(Resource):
        args = parser.parse_args()

        try:
-           return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
+           return ConversationService.pagination_by_last_id(
+               app_model=app_model,
+               user=end_user,
+               last_id=args['last_id'],
+               limit=args['limit'],
+               invoke_from=InvokeFrom.SERVICE_API
+           )
        except services.errors.conversation.LastConversationNotExistsError:
            raise NotFound("Last Conversation Not Exists.")
api/controllers/web/conversation.py

@@ -5,6 +5,7 @@ from werkzeug.exceptions import NotFound
from controllers.web import api
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
+from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode

@@ -37,7 +38,8 @@ class ConversationListApi(WebApiResource):
                user=end_user,
                last_id=args['last_id'],
                limit=args['limit'],
-               pinned=pinned
+               invoke_from=InvokeFrom.WEB_APP,
+               pinned=pinned,
            )
        except LastConversationNotExistsError:
            raise NotFound("Last Conversation Not Exists.")
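All three controllers now call `ConversationService.pagination_by_last_id` with keyword arguments and an explicit `invoke_from`, letting the service branch per surface (for example, the separate `exclude_debug_conversation=True` flag disappears from the explore controller). The updated service signature is not part of this diff; the following is a plausible sketch implied by the call sites, with the enum member names taken from the hunks above and everything else an assumption:

```python
from enum import Enum
from typing import Optional

class InvokeFrom(Enum):
    # Member names appear at the call sites in this diff; the string
    # values here are made up for the sketch.
    EXPLORE = 'explore'
    SERVICE_API = 'service-api'
    WEB_APP = 'web-app'

class ConversationService:
    @classmethod
    def pagination_by_last_id(cls, app_model, user, last_id: Optional[str],
                              limit: int, invoke_from: InvokeFrom,
                              pinned: Optional[bool] = None):
        # Sketch only: return up to `limit` conversations older than
        # `last_id`, filtered by `pinned` when given; `invoke_from`
        # lets the query differ per surface (explore vs. API vs. web app).
        raise NotImplementedError
```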
api/core/agent/base_agent_runner.py

@@ -5,6 +5,7 @@ from datetime import datetime
from typing import Optional, Union, cast

from core.agent.entities import AgentEntity, AgentToolEntity
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner

@@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.file.message_file_parser import MessageFileParser
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage

@@ -22,6 +24,7 @@ from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
+   TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)

@@ -37,7 +40,7 @@ from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
-from models.model import Message, MessageAgentThought
+from models.model import Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)

@@ -45,6 +48,7 @@ logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner):
    def __init__(self, tenant_id: str,
                 application_generate_entity: AgentChatAppGenerateEntity,
+                conversation: Conversation,
                 app_config: AgentChatAppConfig,
                 model_config: ModelConfigWithCredentialsEntity,
                 config: AgentEntity,

@@ -72,6 +76,7 @@ class BaseAgentRunner(AppRunner):
        """
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
+       self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config

@@ -118,6 +123,12 @@ class BaseAgentRunner(AppRunner):
        else:
            self.stream_tool_call = False

+       # check if model supports vision
+       if model_schema and ModelFeature.VISION in (model_schema.features or []):
+           self.files = application_generate_entity.files
+       else:
+           self.files = []

    def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
        -> AgentChatAppGenerateEntity:
        """

@@ -227,6 +238,34 @@ class BaseAgentRunner(AppRunner):

        return prompt_tool

+   def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
+       """
+       Init tools
+       """
+       tool_instances = {}
+       prompt_messages_tools = []
+
+       for tool in self.app_config.agent.tools if self.app_config.agent else []:
+           try:
+               prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
+           except Exception:
+               # api tool may be deleted
+               continue
+           # save tool entity
+           tool_instances[tool.tool_name] = tool_entity
+           # save prompt tool
+           prompt_messages_tools.append(prompt_tool)
+
+       # convert dataset tools into ModelRuntime Tool format
+       for dataset_tool in self.dataset_tools:
+           prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
+           # save prompt tool
+           prompt_messages_tools.append(prompt_tool)
+           # save tool entity
+           tool_instances[dataset_tool.identity.name] = dataset_tool
+
+       return tool_instances, prompt_messages_tools
    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        update prompt message tool

@@ -314,7 +353,7 @@ class BaseAgentRunner(AppRunner):
                           tool_name: str,
                           tool_input: Union[str, dict],
                           thought: str,
-                          observation: Union[str, str],
+                          observation: Union[str, dict],
                           tool_invoke_meta: Union[str, dict],
                           answer: str,
                           messages_ids: list[str],

@@ -412,15 +451,19 @@ class BaseAgentRunner(AppRunner):
        """
        result = []
-       # check if there is a system message in the beginning of the conversation
-       if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
-           result.append(prompt_messages[0])
+       for prompt_message in prompt_messages:
+           if isinstance(prompt_message, SystemPromptMessage):
+               result.append(prompt_message)

        messages: list[Message] = db.session.query(Message).filter(
            Message.conversation_id == self.message.conversation_id,
        ).order_by(Message.created_at.asc()).all()

        for message in messages:
-           result.append(UserPromptMessage(content=message.query))
            if message.id == self.message.id:
                continue

+           result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:

This keeps every system message in the history rather than only a leading one, and swaps the raw query text for `organize_agent_user_prompt`, which attaches file contents when the model supports vision.

@@ -471,3 +514,32 @@ class BaseAgentRunner(AppRunner):
        db.session.close()

        return result

+   def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
+       message_file_parser = MessageFileParser(
+           tenant_id=self.tenant_id,
+           app_id=self.app_config.app_id,
+       )
+
+       files = message.message_files
+       if files:
+           file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
+
+           if file_extra_config:
+               file_objs = message_file_parser.transform_message_files(
+                   files,
+                   file_extra_config
+               )
+           else:
+               file_objs = []
+
+           if not file_objs:
+               return UserPromptMessage(content=message.query)
+           else:
+               prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+               for file_obj in file_objs:
+                   prompt_message_contents.append(file_obj.prompt_message_content)
+
+               return UserPromptMessage(content=prompt_message_contents)
+       else:
+           return UserPromptMessage(content=message.query)
api/core/agent/cot_agent_runner.py

@@ -1,33 +1,36 @@
import json
-import re
+from abc import ABC, abstractmethod
from collections.abc import Generator
-from typing import Literal, Union
+from typing import Union

from core.agent.base_agent_runner import BaseAgentRunner
-from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit
+from core.agent.entities import AgentScratchpadUnit
+from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
-from models.model import Conversation, Message
+from models.model import Message


-class CotAgentRunner(BaseAgentRunner):
+class CotAgentRunner(BaseAgentRunner, ABC):
    _is_first_iteration = True
    _ignore_observation_providers = ['wenxin']
+   _historic_prompt_messages: list[PromptMessage] = None
+   _agent_scratchpad: list[AgentScratchpadUnit] = None
+   _instruction: str = None
+   _query: str = None
+   _prompt_messages_tools: list[PromptMessage] = None

-   def run(self, conversation: Conversation,
-           message: Message,
+   def run(self, message: Message,
            query: str,
            inputs: dict[str, str],
            ) -> Union[Generator, LLMResult]:
@@ -36,9 +39,7 @@ class CotAgentRunner(BaseAgentRunner):
        """
        app_generate_entity = self.application_generate_entity
        self._repack_app_generate_entity(app_generate_entity)
-
-       agent_scratchpad: list[AgentScratchpadUnit] = []
-       self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
+       self._init_react_state(query)

        # check model mode
        if 'Observation' not in app_generate_entity.model_config.stop:

@@ -47,38 +48,19 @@ class CotAgentRunner(BaseAgentRunner):

        app_config = self.app_config

-       # override inputs
+       # init instruction
        inputs = inputs or {}
        instruction = app_config.prompt_template.simple_prompt_template
-       instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
+       self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

-       prompt_messages = self.history_prompt_messages
-
-       # convert tools into ModelRuntime Tool format
-       prompt_messages_tools: list[PromptMessageTool] = []
-       tool_instances = {}
-       for tool in app_config.agent.tools if app_config.agent else []:
-           try:
-               prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
-           except Exception:
-               # api tool may be deleted
-               continue
-           # save tool entity
-           tool_instances[tool.tool_name] = tool_entity
-           # save prompt tool
-           prompt_messages_tools.append(prompt_tool)
-
-       # convert dataset tools into ModelRuntime Tool format
-       for dataset_tool in self.dataset_tools:
-           prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
-           # save prompt tool
-           prompt_messages_tools.append(prompt_tool)
-           # save tool entity
-           tool_instances[dataset_tool.identity.name] = dataset_tool
+       tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
+
+       prompt_messages = self._organize_prompt_messages()

        function_call_state = True
        llm_usage = {
            'usage': None

@@ -103,7 +85,7 @@ class CotAgentRunner(BaseAgentRunner):

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
-               prompt_messages_tools = []
+               self._prompt_messages_tools = []

            message_file_ids = []
@@ -120,18 +102,8 @@ class CotAgentRunner(BaseAgentRunner):
                agent_thought_id=agent_thought.id
            ), PublishFrom.APPLICATION_MANAGER)

-           # update prompt messages
-           prompt_messages = self._organize_cot_prompt_messages(
-               mode=app_generate_entity.model_config.mode,
-               prompt_messages=prompt_messages,
-               tools=prompt_messages_tools,
-               agent_scratchpad=agent_scratchpad,
-               agent_prompt_message=app_config.agent.prompt,
-               instruction=instruction,
-               input=query
-           )
-
            # recalc llm max tokens
+           prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(

@@ -149,7 +121,7 @@ class CotAgentRunner(BaseAgentRunner):
                raise ValueError("failed to invoke llm")

            usage_dict = {}
-           react_chunks = self._handle_stream_react(chunks, usage_dict)
+           react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks)
            scratchpad = AgentScratchpadUnit(
                agent_response='',
                thought='',

@@ -165,30 +137,12 @@ class CotAgentRunner(BaseAgentRunner):
            ), PublishFrom.APPLICATION_MANAGER)

            for chunk in react_chunks:
-               if isinstance(chunk, dict):
-                   scratchpad.agent_response += json.dumps(chunk)
-                   try:
-                       if scratchpad.action:
-                           raise Exception("")
-                       scratchpad.action_str = json.dumps(chunk)
-                       scratchpad.action = AgentScratchpadUnit.Action(
-                           action_name=chunk['action'],
-                           action_input=chunk['action_input']
-                       )
-                   except:
-                       scratchpad.thought += json.dumps(chunk)
-                   yield LLMResultChunk(
-                       model=self.model_config.model,
-                       prompt_messages=prompt_messages,
-                       system_fingerprint='',
-                       delta=LLMResultChunkDelta(
-                           index=0,
-                           message=AssistantPromptMessage(
-                               content=json.dumps(chunk, ensure_ascii=False)  # if ensure_ascii=True, the text in webui maybe garbled text
-                           ),
-                           usage=None
-                       )
-                   )
+               if isinstance(chunk, AgentScratchpadUnit.Action):
+                   action = chunk
+                   # detect action
+                   scratchpad.agent_response += json.dumps(chunk.dict())
+                   scratchpad.action_str = json.dumps(chunk.dict())
+                   scratchpad.action = action
                else:
                    scratchpad.agent_response += chunk
                    scratchpad.thought += chunk
@@ -206,27 +160,29 @@ class CotAgentRunner(BaseAgentRunner):
            )

            scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
-           agent_scratchpad.append(scratchpad)
+           self._agent_scratchpad.append(scratchpad)

            # get llm usage
            if 'usage' in usage_dict:
                increase_usage(llm_usage, usage_dict['usage'])
            else:
                usage_dict['usage'] = LLMUsage.empty_usage()

-           self.save_agent_thought(agent_thought=agent_thought,
-                                   tool_name=scratchpad.action.action_name if scratchpad.action else '',
-                                   tool_input={
-                                       scratchpad.action.action_name: scratchpad.action.action_input
-                                   } if scratchpad.action else '',
-                                   tool_invoke_meta={},
-                                   thought=scratchpad.thought,
-                                   observation='',
-                                   answer=scratchpad.agent_response,
-                                   messages_ids=[],
-                                   llm_usage=usage_dict['usage'])
+           self.save_agent_thought(
+               agent_thought=agent_thought,
+               tool_name=scratchpad.action.action_name if scratchpad.action else '',
+               tool_input={
+                   scratchpad.action.action_name: scratchpad.action.action_input
+               } if scratchpad.action else {},
+               tool_invoke_meta={},
+               thought=scratchpad.thought,
+               observation='',
+               answer=scratchpad.agent_response,
+               messages_ids=[],
+               llm_usage=usage_dict['usage']
+           )

-           if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
+           if not scratchpad.is_final():
                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)
@@ -238,106 +194,43 @@ class CotAgentRunner(BaseAgentRunner):
            if scratchpad.action.action_name.lower() == "final answer":
                # action is final answer, return final answer directly
                try:
-                   final_answer = scratchpad.action.action_input if \
-                       isinstance(scratchpad.action.action_input, str) else \
-                       json.dumps(scratchpad.action.action_input)
+                   if isinstance(scratchpad.action.action_input, dict):
+                       final_answer = json.dumps(scratchpad.action.action_input)
+                   elif isinstance(scratchpad.action.action_input, str):
+                       final_answer = scratchpad.action.action_input
+                   else:
+                       final_answer = f'{scratchpad.action.action_input}'
                except json.JSONDecodeError:
                    final_answer = f'{scratchpad.action.action_input}'
            else:
                function_call_state = True
-               # action is tool call, invoke tool
-               tool_call_name = scratchpad.action.action_name
-               tool_call_args = scratchpad.action.action_input
-               tool_instance = tool_instances.get(tool_call_name)
-               if not tool_instance:
-                   answer = f"there is not a tool named {tool_call_name}"
-                   self.save_agent_thought(
-                       agent_thought=agent_thought,
-                       tool_name='',
-                       tool_input='',
-                       tool_invoke_meta=ToolInvokeMeta.error_instance(
-                           f"there is not a tool named {tool_call_name}"
-                       ).to_dict(),
-                       thought=None,
-                       observation={
-                           tool_call_name: answer
-                       },
-                       answer=answer,
-                       messages_ids=[]
-                   )
-                   self.queue_manager.publish(QueueAgentThoughtEvent(
-                       agent_thought_id=agent_thought.id
-                   ), PublishFrom.APPLICATION_MANAGER)
-               else:
-                   if isinstance(tool_call_args, str):
-                       try:
-                           tool_call_args = json.loads(tool_call_args)
-                       except json.JSONDecodeError:
-                           pass
+               tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
+                   action=scratchpad.action,
+                   tool_instances=tool_instances,
+                   message_file_ids=message_file_ids
+               )
+               scratchpad.observation = tool_invoke_response
+               scratchpad.agent_response = tool_invoke_response

-                   # invoke tool
-                   tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
-                       tool=tool_instance,
-                       tool_parameters=tool_call_args,
-                       user_id=self.user_id,
-                       tenant_id=self.tenant_id,
-                       message=self.message,
-                       invoke_from=self.application_generate_entity.invoke_from,
-                       agent_tool_callback=self.agent_callback
-                   )
-                   # publish files
-                   for message_file, save_as in message_files:
-                       if save_as:
-                           self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
+               self.save_agent_thought(
+                   agent_thought=agent_thought,
+                   tool_name=scratchpad.action.action_name,
+                   tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
+                   thought=scratchpad.thought,
+                   observation={scratchpad.action.action_name: tool_invoke_response},
+                   tool_invoke_meta=tool_invoke_meta.to_dict(),
+                   answer=scratchpad.agent_response,
+                   messages_ids=message_file_ids,
+                   llm_usage=usage_dict['usage']
+               )

-                       # publish message file
-                       self.queue_manager.publish(QueueMessageFileEvent(
-                           message_file_id=message_file.id
-                       ), PublishFrom.APPLICATION_MANAGER)
-                       # add message file ids
-                       message_file_ids.append(message_file.id)
-
-                   # publish files
-                   for message_file, save_as in message_files:
-                       if save_as:
-                           self.variables_pool.set_file(tool_name=tool_call_name,
-                                                        value=message_file.id,
-                                                        name=save_as)
-                       self.queue_manager.publish(QueueMessageFileEvent(
-                           message_file_id=message_file.id
-                       ), PublishFrom.APPLICATION_MANAGER)
-
-                   message_file_ids = [message_file.id for message_file, _ in message_files]
-
-                   observation = tool_invoke_response
-
-                   # save scratchpad
-                   scratchpad.observation = observation
-
-                   # save agent thought
-                   self.save_agent_thought(
-                       agent_thought=agent_thought,
-                       tool_name=tool_call_name,
-                       tool_input={
-                           tool_call_name: tool_call_args
-                       },
-                       tool_invoke_meta={
-                           tool_call_name: tool_invoke_meta.to_dict()
-                       },
-                       thought=None,
-                       observation={
-                           tool_call_name: observation
-                       },
-                       answer=scratchpad.agent_response,
-                       messages_ids=message_file_ids,
-                   )
-                   self.queue_manager.publish(QueueAgentThoughtEvent(
-                       agent_thought_id=agent_thought.id
-                   ), PublishFrom.APPLICATION_MANAGER)
+               self.queue_manager.publish(QueueAgentThoughtEvent(
+                   agent_thought_id=agent_thought.id
+               ), PublishFrom.APPLICATION_MANAGER)

            # update prompt tool message
-           for prompt_tool in prompt_messages_tools:
+           for prompt_tool in self._prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1
@@ -379,96 +272,63 @@ class CotAgentRunner(BaseAgentRunner):
|
||||
system_fingerprint=''
|
||||
)), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
|
||||
                                  -> Generator[Union[str, dict], None, None]:
        def parse_json(json_str):
            try:
                return json.loads(json_str.strip())
            except:
                return json_str

        def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
            if not code_blocks:
                return
            for block in code_blocks:
                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
                yield parse_json(json_text)

        code_block_cache = ''
        code_block_delimiter_count = 0
        in_code_block = False
        json_cache = ''
        json_quote_count = 0
        in_json = False
        got_json = False

        for response in llm_response:
            response = response.delta.message.content
            if not isinstance(response, str):
                continue

            # stream
            index = 0
            while index < len(response):
                steps = 1
                delta = response[index:index+steps]
                if delta == '`':
                    code_block_cache += delta
                    code_block_delimiter_count += 1
                else:
                    if not in_code_block:
                        if code_block_delimiter_count > 0:
                            yield code_block_cache
                            code_block_cache = ''
                    else:
                        code_block_cache += delta
                    code_block_delimiter_count = 0

                if code_block_delimiter_count == 3:
                    if in_code_block:
                        yield from extra_json_from_code_block(code_block_cache)
                        code_block_cache = ''

                    in_code_block = not in_code_block
                    code_block_delimiter_count = 0

                if not in_code_block:
                    # handle single json
                    if delta == '{':
                        json_quote_count += 1
                        in_json = True
                        json_cache += delta
                    elif delta == '}':
                        json_cache += delta
                        if json_quote_count > 0:
                            json_quote_count -= 1
                            if json_quote_count == 0:
                                in_json = False
                                got_json = True
                                index += steps
                                continue
                    else:
                        if in_json:
                            json_cache += delta

                    if got_json:
                        got_json = False
                        yield parse_json(json_cache)
                        json_cache = ''
                        json_quote_count = 0
                        in_json = False

                if not in_code_block and not in_json:
                    yield delta.replace('`', '')

                index += steps

        if code_block_cache:
            yield code_block_cache

        if json_cache:
            yield parse_json(json_cache)

    def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
                              tool_instances: dict[str, Tool],
                              message_file_ids: list[str]) -> tuple[str, ToolInvokeMeta]:
        """
        handle invoke action
        :param action: action
        :param tool_instances: tool instances
        :return: observation, meta
        """
        # action is tool call, invoke tool
        tool_call_name = action.action_name
        tool_call_args = action.action_input
        tool_instance = tool_instances.get(tool_call_name)

        if not tool_instance:
            answer = f"there is not a tool named {tool_call_name}"
            return answer, ToolInvokeMeta.error_instance(answer)

        if isinstance(tool_call_args, str):
            try:
                tool_call_args = json.loads(tool_call_args)
            except json.JSONDecodeError:
                pass

        # invoke tool
        tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
            tool=tool_instance,
            tool_parameters=tool_call_args,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            message=self.message,
            invoke_from=self.application_generate_entity.invoke_from,
            agent_tool_callback=self.agent_callback
        )

        # publish files
        for message_file, save_as in message_files:
            if save_as:
                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)

            # publish message file
            self.queue_manager.publish(QueueMessageFileEvent(
                message_file_id=message_file.id
            ), PublishFrom.APPLICATION_MANAGER)
            # add message file ids
            message_file_ids.append(message_file.id)

        return tool_invoke_response, tool_invoke_meta
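Taken on their own, the two removed helpers implement "find fenced blocks, strip the language tag, try JSON". A self-contained sketch of that path (the sample text is invented for illustration):

import json
import re

def parse_json(json_str):
    try:
        return json.loads(json_str.strip())
    except Exception:
        return json_str

def extra_json_from_code_block(code_block):
    # pull each ``` fenced block, drop an optional leading language tag, parse as JSON
    for block in re.findall(r'```(.*?)```', code_block, re.DOTALL):
        yield parse_json(re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE))

text = 'I will call a tool.\n```json\n{"action": "current_weather", "action_input": {"city": "Berlin"}}\n```'
print(list(extra_json_from_code_block(text)))
# [{'action': 'current_weather', 'action_input': {'city': 'Berlin'}}]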
    def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
        """
        convert dict to action
        """
        return AgentScratchpadUnit.Action(
            action_name=action['action'],
            action_input=action['action_input']
        )

    def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
        """
@@ -482,15 +342,46 @@ class CotAgentRunner(BaseAgentRunner):

        return instruction

    def _init_agent_scratchpad(self,
                               agent_scratchpad: list[AgentScratchpadUnit],
                               messages: list[PromptMessage]
                               ) -> list[AgentScratchpadUnit]:
    def _init_react_state(self, query) -> None:
        """
        init agent scratchpad
        """
        self._query = query
        self._agent_scratchpad = []
        self._historic_prompt_messages = self._organize_historic_prompt_messages()

    @abstractmethod
    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        organize prompt messages
        """

    def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
        """
        format assistant message
        """
        message = ''
        for scratchpad in agent_scratchpad:
            if scratchpad.is_final():
                message += f"Final Answer: {scratchpad.agent_response}"
            else:
                message += f"Thought: {scratchpad.thought}\n\n"
                if scratchpad.action_str:
                    message += f"Action: {scratchpad.action_str}\n\n"
                if scratchpad.observation:
                    message += f"Observation: {scratchpad.observation}\n\n"

        return message

    def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
        """
        organize historic prompt messages
        """
        result: list[PromptMessage] = []
        scratchpad: list[AgentScratchpadUnit] = []
        current_scratchpad: AgentScratchpadUnit = None

        for message in self.history_prompt_messages:
            if isinstance(message, AssistantPromptMessage):
                current_scratchpad = AgentScratchpadUnit(
                    agent_response=message.content,
@@ -505,186 +396,29 @@ class CotAgentRunner(BaseAgentRunner):
                            action_name=message.tool_calls[0].function.name,
                            action_input=json.loads(message.tool_calls[0].function.arguments)
                        )
                        current_scratchpad.action_str = json.dumps(
                            current_scratchpad.action.to_dict()
                        )
                    except:
                        pass

                scratchpad.append(current_scratchpad)
            elif isinstance(message, ToolPromptMessage):
                if current_scratchpad:
                    current_scratchpad.observation = message.content
            elif isinstance(message, UserPromptMessage):
                result.append(message)

                if scratchpad:
                    result.append(AssistantPromptMessage(
                        content=self._format_assistant_message(scratchpad)
                    ))
                    scratchpad = []

        if scratchpad:
            result.append(AssistantPromptMessage(
                content=self._format_assistant_message(scratchpad)
            ))

        return result

    def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                   agent_prompt_message: AgentPromptEntity,
                                   ):
        """
        check chain of thought prompt messages, a standard prompt message is like:
            Respond to the human as helpfully and accurately as possible.

            {{instruction}}

            You have access to the following tools:

            {{tools}}

            Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
            Valid action values: "Final Answer" or {{tool_names}}

            Provide only ONE action per $JSON_BLOB, as shown:

            ```
            {
                "action": $TOOL_NAME,
                "action_input": $ACTION_INPUT
            }
            ```
        """

        # parse agent prompt message
        first_prompt = agent_prompt_message.first_prompt
        next_iteration = agent_prompt_message.next_iteration

        if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
            raise ValueError("first_prompt or next_iteration is required in CoT agent mode")

        # check instruction, tools, and tool_names slots
        if not first_prompt.find("{{instruction}}") >= 0:
            raise ValueError("{{instruction}} is required in first_prompt")
        if not first_prompt.find("{{tools}}") >= 0:
            raise ValueError("{{tools}} is required in first_prompt")
        if not first_prompt.find("{{tool_names}}") >= 0:
            raise ValueError("{{tool_names}} is required in first_prompt")

        if mode == "completion":
            if not first_prompt.find("{{query}}") >= 0:
                raise ValueError("{{query}} is required in first_prompt")
            if not first_prompt.find("{{agent_scratchpad}}") >= 0:
                raise ValueError("{{agent_scratchpad}} is required in first_prompt")

        if mode == "completion":
            if not next_iteration.find("{{observation}}") >= 0:
                raise ValueError("{{observation}} is required in next_iteration")

    def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
        """
        convert agent scratchpad list to str
        """
        next_iteration = self.app_config.agent.prompt.next_iteration

        result = ''
        for scratchpad in agent_scratchpad:
            result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \
                next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available')

        return result

    def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                      prompt_messages: list[PromptMessage],
                                      tools: list[PromptMessageTool],
                                      agent_scratchpad: list[AgentScratchpadUnit],
                                      agent_prompt_message: AgentPromptEntity,
                                      instruction: str,
                                      input: str,
                                      ) -> list[PromptMessage]:
        """
        organize chain of thought prompt messages, a standard prompt message is like:
            Respond to the human as helpfully and accurately as possible.

            {{instruction}}

            You have access to the following tools:

            {{tools}}

            Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
            Valid action values: "Final Answer" or {{tool_names}}

            Provide only ONE action per $JSON_BLOB, as shown:

            ```
            {{{{
                "action": $TOOL_NAME,
                "action_input": $ACTION_INPUT
            }}}}
            ```
        """

        self._check_cot_prompt_messages(mode, agent_prompt_message)

        # parse agent prompt message
        first_prompt = agent_prompt_message.first_prompt

        # parse tools
        tools_str = self._jsonify_tool_prompt_messages(tools)

        # parse tools name
        tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'

        # get system message
        system_message = first_prompt.replace("{{instruction}}", instruction) \
            .replace("{{tools}}", tools_str) \
            .replace("{{tool_names}}", tool_names)

        # organize prompt messages
        if mode == "chat":
            # override system message
            overridden = False
            prompt_messages = prompt_messages.copy()
            for prompt_message in prompt_messages:
                if isinstance(prompt_message, SystemPromptMessage):
                    prompt_message.content = system_message
                    overridden = True
                    break

            # convert tool prompt messages to user prompt messages
            for idx, prompt_message in enumerate(prompt_messages):
                if isinstance(prompt_message, ToolPromptMessage):
                    prompt_messages[idx] = UserPromptMessage(
                        content=prompt_message.content
                    )

            if not overridden:
                prompt_messages.insert(0, SystemPromptMessage(
                    content=system_message,
                ))

            # add assistant message
            if len(agent_scratchpad) > 0 and not self._is_first_iteration:
                prompt_messages.append(AssistantPromptMessage(
                    content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''),
                ))

            # add user message
            if len(agent_scratchpad) > 0 and not self._is_first_iteration:
                prompt_messages.append(UserPromptMessage(
                    content=(agent_scratchpad[-1].observation or 'It seems that no response is available'),
                ))

            self._is_first_iteration = False

            return prompt_messages
        elif mode == "completion":
            # parse agent scratchpad
            agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
            self._is_first_iteration = False
            # parse prompt messages
            return [UserPromptMessage(
                content=first_prompt.replace("{{instruction}}", instruction)
                .replace("{{tools}}", tools_str)
                .replace("{{tool_names}}", tool_names)
                .replace("{{query}}", input)
                .replace("{{agent_scratchpad}}", agent_scratchpad_str),
            )]
        else:
            raise ValueError(f"mode {mode} is not supported")

    def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
        """
        jsonify tool prompt messages
        """
        tools = jsonable_encoder(tools)
        try:
            return json.dumps(tools, ensure_ascii=False)
        except json.JSONDecodeError:
            return json.dumps(tools)

api/core/agent/cot_chat_agent_runner.py (new file, 71 lines)
@@ -0,0 +1,71 @@
import json

from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder


class CotChatAgentRunner(CotAgentRunner):
    def _organize_system_prompt(self) -> SystemPromptMessage:
        """
        Organize system prompt
        """
        prompt_entity = self.app_config.agent.prompt
        first_prompt = prompt_entity.first_prompt

        system_prompt = first_prompt \
            .replace("{{instruction}}", self._instruction) \
            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))

        return SystemPromptMessage(content=system_prompt)

    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        Organize
        """
        # organize system prompt
        system_message = self._organize_system_prompt()

        # organize historic prompt messages
        historic_messages = self._historic_prompt_messages

        # organize current assistant messages
        agent_scratchpad = self._agent_scratchpad
        if not agent_scratchpad:
            assistant_messages = []
        else:
            assistant_message = AssistantPromptMessage(content='')
            for unit in agent_scratchpad:
                if unit.is_final():
                    assistant_message.content += f"Final Answer: {unit.agent_response}"
                else:
                    assistant_message.content += f"Thought: {unit.thought}\n\n"
                    if unit.action_str:
                        assistant_message.content += f"Action: {unit.action_str}\n\n"
                    if unit.observation:
                        assistant_message.content += f"Observation: {unit.observation}\n\n"

            assistant_messages = [assistant_message]

        # query messages
        query_messages = UserPromptMessage(content=self._query)

        if assistant_messages:
            messages = [
                system_message,
                *historic_messages,
                query_messages,
                *assistant_messages,
                UserPromptMessage(content='continue')
            ]
        else:
            messages = [system_message, *historic_messages, query_messages]

        # join all messages
        return messages
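On a follow-up ReAct iteration this method therefore yields a list shaped roughly like the sketch below; the message contents are hypothetical and only the ordering matters:

# hypothetical shape of _organize_prompt_messages() on a second iteration
messages = [
    SystemPromptMessage(content="...first_prompt with {{instruction}}/{{tools}}/{{tool_names}} filled in..."),
    UserPromptMessage(content="an earlier question"),                      # historic messages
    AssistantPromptMessage(content="Thought: ...\n\nFinal Answer: ..."),   # historic scratchpad, flattened
    UserPromptMessage(content="the current query"),
    AssistantPromptMessage(content="Thought: ...\n\nAction: ...\n\nObservation: ...\n\n"),
    UserPromptMessage(content='continue'),                                 # nudges the model to resume reasoning
]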
api/core/agent/cot_completion_agent_runner.py (new file, 69 lines)
@@ -0,0 +1,69 @@
import json

from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
from core.model_runtime.utils.encoders import jsonable_encoder


class CotCompletionAgentRunner(CotAgentRunner):
    def _organize_instruction_prompt(self) -> str:
        """
        Organize instruction prompt
        """
        prompt_entity = self.app_config.agent.prompt
        first_prompt = prompt_entity.first_prompt

        system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))

        return system_prompt

    def _organize_historic_prompt(self) -> str:
        """
        Organize historic prompt
        """
        historic_prompt_messages = self._historic_prompt_messages
        historic_prompt = ""

        for message in historic_prompt_messages:
            if isinstance(message, UserPromptMessage):
                historic_prompt += f"Question: {message.content}\n\n"
            elif isinstance(message, AssistantPromptMessage):
                historic_prompt += message.content + "\n\n"

        return historic_prompt

    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        Organize prompt messages
        """
        # organize system prompt
        system_prompt = self._organize_instruction_prompt()

        # organize historic prompt messages
        historic_prompt = self._organize_historic_prompt()

        # organize current assistant messages
        agent_scratchpad = self._agent_scratchpad
        assistant_prompt = ''
        for unit in agent_scratchpad:
            if unit.is_final():
                assistant_prompt += f"Final Answer: {unit.agent_response}"
            else:
                assistant_prompt += f"Thought: {unit.thought}\n\n"
                if unit.action_str:
                    assistant_prompt += f"Action: {unit.action_str}\n\n"
                if unit.observation:
                    assistant_prompt += f"Observation: {unit.observation}\n\n"

        # query messages
        query_prompt = f"Question: {self._query}"

        # join all messages
        prompt = system_prompt \
            .replace("{{historic_messages}}", historic_prompt) \
            .replace("{{agent_scratchpad}}", assistant_prompt) \
            .replace("{{query}}", query_prompt)

        return [UserPromptMessage(content=prompt)]
@@ -34,12 +34,29 @@ class AgentScratchpadUnit(BaseModel):
        action_name: str
        action_input: Union[dict, str]

        def to_dict(self) -> dict:
            """
            Convert to dictionary.
            """
            return {
                'action': self.action_name,
                'action_input': self.action_input,
            }

    agent_response: Optional[str] = None
    thought: Optional[str] = None
    action_str: Optional[str] = None
    observation: Optional[str] = None
    action: Optional[Action] = None

    def is_final(self) -> bool:
        """
        Check if the scratchpad unit is final.
        """
        return self.action is None or (
            'final' in self.action.action_name.lower() and
            'answer' in self.action.action_name.lower()
        )
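A minimal sketch of how these entities behave, assuming the module is importable as core.agent.entities (the tool name and input are invented):

from core.agent.entities import AgentScratchpadUnit

unit = AgentScratchpadUnit(
    thought="I should look up the weather",
    action=AgentScratchpadUnit.Action(
        action_name="current_weather",          # hypothetical tool name
        action_input={"city": "Berlin"},
    ),
)
print(unit.action.to_dict())  # {'action': 'current_weather', 'action_input': {'city': 'Berlin'}}
print(unit.is_final())        # False: an ordinary tool action is not final

final_unit = AgentScratchpadUnit(agent_response="It is sunny.")
print(final_unit.is_final())  # True: no action at all counts as final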

class AgentEntity(BaseModel):
    """
@@ -1,6 +1,7 @@
import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Union

from core.agent.base_agent_runner import BaseAgentRunner
@@ -10,21 +11,21 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Conversation, Message, MessageAgentThought
from models.model import Message

logger = logging.getLogger(__name__)

class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, conversation: Conversation,
            message: Message,
            query: str,
    def run(self,
            message: Message, query: str, **kwargs: Any
            ) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
@@ -35,40 +36,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):

        prompt_template = app_config.prompt_template.simple_prompt_template or ''
        prompt_messages = self.history_prompt_messages
        prompt_messages = self.organize_prompt_messages(
            prompt_template=prompt_template,
            query=query,
            prompt_messages=prompt_messages
        )
        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
        prompt_messages = self._organize_user_query(query, prompt_messages)

        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: list[PromptMessageTool] = []
        tool_instances = {}
        for tool in app_config.agent.tools if app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        # continue to run until there is not any tool call
        function_call_state = True
        agent_thoughts: list[MessageAgentThought] = []
        llm_usage = {
            'usage': None
        }
@@ -207,19 +185,25 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                )
            )

        assistant_message = AssistantPromptMessage(
            content='',
            tool_calls=[]
        )
        if tool_calls:
            prompt_messages.append(AssistantPromptMessage(
                content='',
                name='',
                tool_calls=[AssistantPromptMessage.ToolCall(
            assistant_message.tool_calls=[
                AssistantPromptMessage.ToolCall(
                    id=tool_call[0],
                    type='function',
                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                        name=tool_call[1],
                        arguments=json.dumps(tool_call[2], ensure_ascii=False)
                    )
                ) for tool_call in tool_calls]
            ))
                ) for tool_call in tool_calls
            ]
        else:
            assistant_message.content = response

        prompt_messages.append(assistant_message)

        # save thought
        self.save_agent_thought(
@@ -239,12 +223,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):

            final_answer += response + '\n'

        # update prompt messages
        if response.strip():
            prompt_messages.append(AssistantPromptMessage(
                content=response,
            ))

        # call tools
        tool_responses = []
        for tool_call_id, tool_call_name, tool_call_args in tool_calls:
@@ -287,9 +265,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                }

                tool_responses.append(tool_response)
                prompt_messages = self.organize_prompt_messages(
                    prompt_template=prompt_template,
                    query=None,
                prompt_messages = self._organize_assistant_message(
                    tool_call_id=tool_call_id,
                    tool_call_name=tool_call_name,
                    tool_response=tool_response['tool_response'],
@@ -324,6 +300,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):

            iteration_step += 1

        prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -386,29 +364,68 @@ class FunctionCallAgentRunner(BaseAgentRunner):

        return tool_calls

    def organize_prompt_messages(self, prompt_template: str,
                                 query: str = None,
                                 tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
                                 prompt_messages: list[PromptMessage] = None
                                 ) -> list[PromptMessage]:
    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
        """
        Organize prompt messages
        Initialize system message
        """
        if not prompt_messages:
            prompt_messages = [
        if not prompt_messages and prompt_template:
            return [
                SystemPromptMessage(content=prompt_template),
                UserPromptMessage(content=query),
            ]

        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

        return prompt_messages

    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            prompt_message_contents = [TextPromptMessageContent(data=query)]
            for file_obj in self.files:
                prompt_message_contents.append(file_obj.prompt_message_content)

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            if tool_response:
                prompt_messages = prompt_messages.copy()
                prompt_messages.append(
                    ToolPromptMessage(
                        content=tool_response,
                        tool_call_id=tool_call_id,
                        name=tool_call_name,
                    )
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
                                    prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
        """
        Organize assistant message
        """
        prompt_messages = deepcopy(prompt_messages)

        if tool_response is not None:
            prompt_messages.append(
                ToolPromptMessage(
                    content=tool_response,
                    tool_call_id=tool_call_id,
                    name=tool_call_name,
                )
            )

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        As for now, gpt supports both fc and vision at the first iteration.
        We need to remove the image messages from the prompt messages at the first iteration.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = '\n'.join([
                        content.data if content.type == PromptMessageContentType.TEXT else
                        '[image]' if content.type == PromptMessageContentType.IMAGE else
                        '[file]'
                        for content in prompt_message.content
                    ])

        return prompt_messages
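In effect _clear_user_prompt_image_messages flattens a multimodal user turn into plain text; a before/after sketch with invented content (the content classes are shown for illustration only):

# before: a list-content user message carrying text plus an image
UserPromptMessage(content=[
    TextPromptMessageContent(data="What is in this picture?"),
    ImagePromptMessageContent(data="https://example.com/cat.png"),
])
# after: one joined string, with non-text parts replaced by markers
UserPromptMessage(content="What is in this picture?\n[image]")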
api/core/agent/output_parser/cot_output_parser.py (new file, 183 lines)
@@ -0,0 +1,183 @@
import json
import re
from collections.abc import Generator
from typing import Union

from core.agent.entities import AgentScratchpadUnit
from core.model_runtime.entities.llm_entities import LLMResultChunk


class CotAgentOutputParser:
    @classmethod
    def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None]) -> \
            Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
        def parse_action(json_str):
            try:
                action = json.loads(json_str)
                action_name = None
                action_input = None

                for key, value in action.items():
                    if 'input' in key.lower():
                        action_input = value
                    else:
                        action_name = value

                if action_name is not None and action_input is not None:
                    return AgentScratchpadUnit.Action(
                        action_name=action_name,
                        action_input=action_input,
                    )
                else:
                    return json_str or ''
            except:
                return json_str or ''

        def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
            if not code_blocks:
                return
            for block in code_blocks:
                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
                yield parse_action(json_text)

        code_block_cache = ''
        code_block_delimiter_count = 0
        in_code_block = False
        json_cache = ''
        json_quote_count = 0
        in_json = False
        got_json = False

        action_cache = ''
        action_str = 'action:'
        action_idx = 0

        thought_cache = ''
        thought_str = 'thought:'
        thought_idx = 0

        for response in llm_response:
            response = response.delta.message.content
            if not isinstance(response, str):
                continue

            # stream
            index = 0
            while index < len(response):
                steps = 1
                delta = response[index:index+steps]
                last_character = response[index-1] if index > 0 else ''

                if delta == '`':
                    code_block_cache += delta
                    code_block_delimiter_count += 1
                else:
                    if not in_code_block:
                        if code_block_delimiter_count > 0:
                            yield code_block_cache
                            code_block_cache = ''
                    else:
                        code_block_cache += delta
                    code_block_delimiter_count = 0

                if not in_code_block and not in_json:
                    if delta.lower() == action_str[action_idx] and action_idx == 0:
                        if last_character not in ['\n', ' ', '']:
                            index += steps
                            yield delta
                            continue

                        action_cache += delta
                        action_idx += 1
                        if action_idx == len(action_str):
                            action_cache = ''
                            action_idx = 0
                        index += steps
                        continue
                    elif delta.lower() == action_str[action_idx] and action_idx > 0:
                        action_cache += delta
                        action_idx += 1
                        if action_idx == len(action_str):
                            action_cache = ''
                            action_idx = 0
                        index += steps
                        continue
                    else:
                        if action_cache:
                            yield action_cache
                            action_cache = ''
                            action_idx = 0

                    if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
                        if last_character not in ['\n', ' ', '']:
                            index += steps
                            yield delta
                            continue

                        thought_cache += delta
                        thought_idx += 1
                        if thought_idx == len(thought_str):
                            thought_cache = ''
                            thought_idx = 0
                        index += steps
                        continue
                    elif delta.lower() == thought_str[thought_idx] and thought_idx > 0:
                        thought_cache += delta
                        thought_idx += 1
                        if thought_idx == len(thought_str):
                            thought_cache = ''
                            thought_idx = 0
                        index += steps
                        continue
                    else:
                        if thought_cache:
                            yield thought_cache
                            thought_cache = ''
                            thought_idx = 0

                if code_block_delimiter_count == 3:
                    if in_code_block:
                        yield from extra_json_from_code_block(code_block_cache)
                        code_block_cache = ''

                    in_code_block = not in_code_block
                    code_block_delimiter_count = 0

                if not in_code_block:
                    # handle single json
                    if delta == '{':
                        json_quote_count += 1
                        in_json = True
                        json_cache += delta
                    elif delta == '}':
                        json_cache += delta
                        if json_quote_count > 0:
                            json_quote_count -= 1
                            if json_quote_count == 0:
                                in_json = False
                                got_json = True
                                index += steps
                                continue
                    else:
                        if in_json:
                            json_cache += delta

                    if got_json:
                        got_json = False
                        yield parse_action(json_cache)
                        json_cache = ''
                        json_quote_count = 0
                        in_json = False

                if not in_code_block and not in_json:
                    yield delta.replace('`', '')

                index += steps

        if code_block_cache:
            yield code_block_cache

        if json_cache:
            yield parse_action(json_cache)
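A rough usage sketch of the parser; the chunk objects are stand-ins, since only .delta.message.content is read from each LLMResultChunk:

from types import SimpleNamespace

from core.agent.output_parser.cot_output_parser import CotAgentOutputParser

def fake_chunks(*texts):
    # stand-ins for LLMResultChunk instances
    for t in texts:
        yield SimpleNamespace(delta=SimpleNamespace(message=SimpleNamespace(content=t)))

stream = fake_chunks('Thought: check the weather\n',
                     'Action: {"action": "current_weather", "action_input": {"city": "Berlin"}}')

for piece in CotAgentOutputParser.handle_react_stream_output(stream):
    print(repr(piece))
# plain text streams out as one-character strings (the 'thought:' / 'action:'
# prefixes are swallowed); the JSON blob comes out as one AgentScratchpadUnit.Action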
@@ -1,7 +1,8 @@
import logging
from typing import cast

from core.agent.cot_agent_runner import CotAgentRunner
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
@@ -11,8 +12,8 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, Mo
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
@@ -207,48 +208,40 @@ class AgentChatAppRunner(AppRunner):

        # start agent runner
        if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
            assistant_cot_runner = CotAgentRunner(
                tenant_id=app_config.tenant_id,
                application_generate_entity=application_generate_entity,
                app_config=app_config,
                model_config=application_generate_entity.model_config,
                config=agent_entity,
                queue_manager=queue_manager,
                message=message,
                user_id=application_generate_entity.user_id,
                memory=memory,
                prompt_messages=prompt_message,
                variables_pool=tool_variables,
                db_variables=tool_conversation_variables,
                model_instance=model_instance
            )
            invoke_result = assistant_cot_runner.run(
                conversation=conversation,
                message=message,
                query=query,
                inputs=inputs,
            )
            # check LLM mode
            if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
                runner_cls = CotChatAgentRunner
            elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value:
                runner_cls = CotCompletionAgentRunner
            else:
                raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
        elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
            assistant_fc_runner = FunctionCallAgentRunner(
                tenant_id=app_config.tenant_id,
                application_generate_entity=application_generate_entity,
                app_config=app_config,
                model_config=application_generate_entity.model_config,
                config=agent_entity,
                queue_manager=queue_manager,
                message=message,
                user_id=application_generate_entity.user_id,
                memory=memory,
                prompt_messages=prompt_message,
                variables_pool=tool_variables,
                db_variables=tool_conversation_variables,
                model_instance=model_instance
            )
            invoke_result = assistant_fc_runner.run(
                conversation=conversation,
                message=message,
                query=query,
            )
            runner_cls = FunctionCallAgentRunner
        else:
            raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")

        runner = runner_cls(
            tenant_id=app_config.tenant_id,
            application_generate_entity=application_generate_entity,
            conversation=conversation,
            app_config=app_config,
            model_config=application_generate_entity.model_config,
            config=agent_entity,
            queue_manager=queue_manager,
            message=message,
            user_id=application_generate_entity.user_id,
            memory=memory,
            prompt_messages=prompt_message,
            variables_pool=tool_variables,
            db_variables=tool_conversation_variables,
            model_instance=model_instance
        )

        invoke_result = runner.run(
            message=message,
            query=query,
            inputs=inputs,
        )

        # handle invoke result
        self._handle_invoke_result(
@@ -156,6 +156,8 @@ class ChatAppRunner(AppRunner):

            dataset_retrieval = DatasetRetrieval()
            context = dataset_retrieval.retrieve(
                app_id=app_record.id,
                user_id=application_generate_entity.user_id,
                tenant_id=app_record.tenant_id,
                model_config=application_generate_entity.model_config,
                config=app_config.dataset,
@@ -116,6 +116,8 @@ class CompletionAppRunner(AppRunner):

            dataset_retrieval = DatasetRetrieval()
            context = dataset_retrieval.retrieve(
                app_id=app_record.id,
                user_id=application_generate_entity.user_id,
                tenant_id=app_record.tenant_id,
                model_config=application_generate_entity.model_config,
                config=dataset_config,
@@ -122,7 +122,7 @@ class MessageEndStreamResponse(StreamResponse):
    """
    event: StreamEvent = StreamEvent.MESSAGE_END
    id: str
    metadata: Optional[dict] = None
    metadata: dict = {}
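The mutable {} default is safe here because pydantic copies field defaults per instance, unlike a plain class attribute; a quick sketch:

from pydantic import BaseModel

class Resp(BaseModel):
    metadata: dict = {}

a, b = Resp(), Resp()
a.metadata['k'] = 'v'
print(b.metadata)  # {} - each instance gets its own copy of the default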


class MessageFileStreamResponse(StreamResponse):
@@ -1,12 +1,32 @@
import os
from typing import Any, Optional, Union
from typing import Any, Optional, TextIO, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from pydantic import BaseModel

_TEXT_COLOR_MAPPING = {
    "blue": "36;1",
    "yellow": "33;1",
    "pink": "38;5;200",
    "green": "32;1",
    "red": "31;1",
}

class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):

def get_colored_text(text: str, color: str) -> str:
    """Get colored text."""
    color_str = _TEXT_COLOR_MAPPING[color]
    return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"


def print_text(
    text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
    """Print text with highlighting and no end characters."""
    text_to_print = get_colored_text(text, color) if color else text
    print(text_to_print, end=end, file=file)
    if file:
        file.flush()  # ensure all printed content are written to file


class DifyAgentCallbackHandler(BaseModel):
    """Callback Handler that prints to std out."""
    color: Optional[str] = ''
    current_loop = 1
@@ -41,7 +41,8 @@ class CacheEmbedding(Embeddings):
        embedding_queue_embeddings = []
        try:
            model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
            model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
            model_schema = model_type_instance.get_model_schema(self._model_instance.model,
                                                                self._model_instance.credentials)
            max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
                if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
            for i in range(0, len(embedding_queue_texts), max_chunks):
@@ -61,17 +62,20 @@ class CacheEmbedding(Embeddings):
        except Exception as e:
            logging.exception('Failed transform embedding: ', e)
        cache_embeddings = []
        for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
            text_embeddings[i] = embedding
            hash = helper.generate_text_hash(texts[i])
            if hash not in cache_embeddings:
                embedding_cache = Embedding(model_name=self._model_instance.model,
                                            hash=hash,
                                            provider_name=self._model_instance.provider)
                embedding_cache.set_embedding(embedding)
                db.session.add(embedding_cache)
                cache_embeddings.append(hash)
        db.session.commit()
        try:
            for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
                text_embeddings[i] = embedding
                hash = helper.generate_text_hash(texts[i])
                if hash not in cache_embeddings:
                    embedding_cache = Embedding(model_name=self._model_instance.model,
                                                hash=hash,
                                                provider_name=self._model_instance.provider)
                    embedding_cache.set_embedding(embedding)
                    db.session.add(embedding_cache)
                    cache_embeddings.append(hash)
            db.session.commit()
        except IntegrityError:
            db.session.rollback()
        except Exception as ex:
            db.session.rollback()
            logger.error('Failed to embed documents: ', ex)
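The new try/except turns a unique-key collision from a concurrent writer into a no-op; the same pattern in miniature (session and row are stand-ins for a SQLAlchemy session and a model with a unique hash column):

from sqlalchemy.exc import IntegrityError

def cache_embedding(session, row):
    """Insert a cached embedding; tolerate a concurrent writer winning the race."""
    try:
        session.add(row)
        session.commit()
    except IntegrityError:
        # another request inserted the same (model, hash, provider) row first;
        # the cache entry already exists, so roll back instead of failing
        session.rollback()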
@@ -29,16 +29,16 @@ class NodeJsTemplateTransformer(TemplateTransformer):
            :param inputs: inputs
            :return:
        """

        # transform inputs to json string
        inputs_str = json.dumps(inputs, indent=4)
        inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False)

        # replace code and inputs
        runner = NODEJS_RUNNER.replace('{{code}}', code)
        runner = runner.replace('{{inputs}}', inputs_str)

        return runner, NODEJS_PRELOAD

    @classmethod
    def transform_response(cls, response: str) -> dict:
        """
@@ -62,10 +62,10 @@ class Jinja2TemplateTransformer(TemplateTransformer):

        # transform jinja2 template to python code
        runner = PYTHON_RUNNER.replace('{{code}}', code)
        runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4))
        runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4, ensure_ascii=False))

        return runner, JINJA2_PRELOAD

    @classmethod
    def transform_response(cls, response: str) -> dict:
        """
@@ -81,4 +81,4 @@ class Jinja2TemplateTransformer(TemplateTransformer):

        return {
            'result': result
        }
        }
@@ -34,7 +34,7 @@ class PythonTemplateTransformer(TemplateTransformer):
        """

        # transform inputs to json string
        inputs_str = json.dumps(inputs, indent=4)
        inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False)

        # replace code and inputs
        runner = PYTHON_RUNNER.replace('{{code}}', code)
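The recurring ensure_ascii=False change keeps non-ASCII inputs readable in the generated runner script instead of being escaped to \uXXXX sequences:

import json

inputs = {"greeting": "你好"}
print(json.dumps(inputs, indent=4))                      # "\u4f60\u597d" - escaped to ASCII
print(json.dumps(inputs, indent=4, ensure_ascii=False))  # "你好" - written through verbatim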
@@ -19,6 +19,7 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -657,18 +658,25 @@ class IndexingRunner:
        if embedding_model_instance:
            embedding_model_type_instance = embedding_model_instance.model_type_instance
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = []
            for i in range(0, len(documents), chunk_size):
                chunk_documents = documents[i:i + chunk_size]
                futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
                                               chunk_documents, dataset,
                                               dataset_document, embedding_model_instance,
                                               embedding_model_type_instance))
        # create keyword index
        create_keyword_thread = threading.Thread(target=self._process_keyword_index,
                                                 args=(current_app._get_current_object(),
                                                       dataset.id, dataset_document.id, documents))
        create_keyword_thread.start()
        if dataset.indexing_technique == 'high_quality':
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                futures = []
                for i in range(0, len(documents), chunk_size):
                    chunk_documents = documents[i:i + chunk_size]
                    futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
                                                   chunk_documents, dataset,
                                                   dataset_document, embedding_model_instance,
                                                   embedding_model_type_instance))

        for future in futures:
            tokens += future.result()
                for future in futures:
                    tokens += future.result()

        create_keyword_thread.join()
        indexing_end_at = time.perf_counter()

        # update document status to completed
@@ -682,6 +690,27 @@ class IndexingRunner:
            }
        )

    def _process_keyword_index(self, flask_app, dataset_id, document_id, documents):
        with flask_app.app_context():
            dataset = Dataset.query.filter_by(id=dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")
            keyword = Keyword(dataset)
            keyword.create(documents)
            if dataset.indexing_technique != 'high_quality':
                document_ids = [document.metadata['doc_id'] for document in documents]
                db.session.query(DocumentSegment).filter(
                    DocumentSegment.document_id == document_id,
                    DocumentSegment.index_node_id.in_(document_ids),
                    DocumentSegment.status == "indexing"
                ).update({
                    DocumentSegment.status: "completed",
                    DocumentSegment.enabled: True,
                    DocumentSegment.completed_at: datetime.datetime.utcnow()
                })

                db.session.commit()

    def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
                       embedding_model_instance, embedding_model_type_instance):
        with flask_app.app_context():
@@ -700,7 +729,7 @@ class IndexingRunner:
            )

            # load index
            index_processor.load(dataset, chunk_documents)
            index_processor.load(dataset, chunk_documents, with_keywords=False)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
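The concurrency shape introduced here, reduced to a framework-free sketch (build_keyword_index and embed_chunk are hypothetical stand-ins for Keyword.create and _process_chunk):

import threading
from concurrent.futures import ThreadPoolExecutor

def build_keyword_index(documents):
    """Stand-in for Keyword(dataset).create(documents)."""

def embed_chunk(chunk):
    """Stand-in for _process_chunk; returns the token count for the chunk."""
    return len(chunk)

def index_documents(documents, chunk_size=10):
    # keyword indexing runs on its own thread...
    keyword_thread = threading.Thread(target=build_keyword_index, args=(documents,))
    keyword_thread.start()

    # ...while embedding work fans out to a pool of workers
    tokens = 0
    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(embed_chunk, documents[i:i + chunk_size])
                   for i in range(0, len(documents), chunk_size)]
        for future in futures:
            tokens += future.result()  # re-raises any worker exception here

    keyword_thread.join()  # both sides must finish before the document is marked complete
    return tokens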
@@ -3,6 +3,7 @@ from core.file.message_file_parser import MessageFileParser
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageRole,
    TextPromptMessageContent,
@@ -124,7 +125,17 @@ class TokenBufferMemory:
            else:
                continue

            message = f"{role}: {m.content}"
            string_messages.append(message)
            if isinstance(m.content, list):
                inner_msg = ""
                for content in m.content:
                    if isinstance(content, TextPromptMessageContent):
                        inner_msg += f"{content.data}\n"
                    elif isinstance(content, ImagePromptMessageContent):
                        inner_msg += "[image]\n"

                string_messages.append(f"{role}: {inner_msg.strip()}")
            else:
                message = f"{role}: {m.content}"
                string_messages.append(message)

        return "\n".join(string_messages)
@@ -99,6 +99,12 @@ model_credential_schema:
        show_on:
          - variable: __model_type
            value: llm
      - label:
          en_US: gpt-4-turbo-2024-04-09
        value: gpt-4-turbo-2024-04-09
        show_on:
          - variable: __model_type
            value: llm
      - label:
          en_US: gpt-4-0125-preview
        value: gpt-4-0125-preview
@@ -74,7 +74,7 @@ provider_credential_schema:
      label:
        en_US: Available Model Name
        zh_Hans: 可用模型名称
      type: secret-input
      type: text-input
      placeholder:
        en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation.
        zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如:amazon.titan-text-lite-v1)
@@ -2,8 +2,6 @@ model: amazon.titan-text-express-v1
label:
  en_US: Titan Text G1 - Express
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
@@ -2,8 +2,6 @@ model: amazon.titan-text-lite-v1
label:
  en_US: Titan Text G1 - Lite
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
@@ -50,3 +50,4 @@ pricing:
  output: '0.024'
  unit: '0.001'
  currency: USD
deprecated: true
@@ -22,7 +22,7 @@ parameter_rules:
    min: 0
    max: 500
    default: 0
  - name: max_tokens_to_sample
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 4096
@@ -8,9 +8,9 @@ model_properties:
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
  - name: p
    use_template: top_p
  - name: top_k
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
@@ -19,7 +19,7 @@ parameter_rules:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens_to_sample
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 4096
@@ -402,25 +402,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
        :param credentials: model credentials
        :return:
        """

        if "anthropic.claude-3" in model:
            try:
                self._invoke_claude(model=model,
                                    credentials=credentials,
                                    prompt_messages=[{"role": "user", "content": "ping"}],
                                    model_parameters={},
                                    stop=None,
                                    stream=False)

            except Exception as ex:
                raise CredentialsValidateFailedError(str(ex))

        required_params = {}
        if "anthropic" in model:
            required_params = {
                "max_tokens": 32,
            }
        elif "ai21" in model:
            # ValidationException: Malformed input request: #/temperature: expected type: Number, found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null, please reformat your input and try again.
            required_params = {
                "temperature": 0.7,
                "topP": 0.9,
                "maxTokens": 32,
            }

        try:
            ping_message = UserPromptMessage(content="ping")
            self._generate(model=model,
            self._invoke(model=model,
                          credentials=credentials,
                          prompt_messages=[ping_message],
                          model_parameters={},
                          model_parameters=required_params,
                          stream=False)

        except ClientError as ex:
@@ -503,7 +503,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):

        if model_prefix == "amazon":
            payload["textGenerationConfig"] = { **model_parameters }
            payload["textGenerationConfig"]["stopSequences"] = ["User:"] + (stop if stop else [])
            payload["textGenerationConfig"]["stopSequences"] = ["User:"]

            payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)

@@ -513,10 +513,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
            payload["maxTokens"] = model_parameters.get("maxTokens")
            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)

            # jurassic models only support a single stop sequence
            if stop:
                payload["stopSequences"] = stop[0]

            if model_parameters.get("presencePenalty"):
                payload["presencePenalty"] = {model_parameters.get("presencePenalty")}
            if model_parameters.get("frequencyPenalty"):
@@ -1,3 +1,5 @@
- command-r
- command-r-plus
- command-chat
- command-light-chat
- command-nightly-chat
@@ -31,7 +31,7 @@ parameter_rules:
    max: 500
  - name: max_tokens
    use_template: max_tokens
    default: 256
    default: 1024
    max: 4096
  - name: preamble_override
    label:
@@ -31,7 +31,7 @@ parameter_rules:
    max: 500
  - name: max_tokens
    use_template: max_tokens
    default: 256
    default: 1024
    max: 4096
  - name: preamble_override
    label:
@@ -31,7 +31,7 @@ parameter_rules:
    max: 500
  - name: max_tokens
    use_template: max_tokens
    default: 256
    default: 1024
    max: 4096
  - name: preamble_override
    label:
@@ -35,7 +35,7 @@ parameter_rules:
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 256
    default: 1024
    max: 4096
pricing:
  input: '0.3'
@@ -35,7 +35,7 @@ parameter_rules:
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 256
    default: 1024
    max: 4096
pricing:
  input: '0.3'
@@ -31,7 +31,7 @@ parameter_rules:
    max: 500
  - name: max_tokens
    use_template: max_tokens
    default: 256
    default: 1024
    max: 4096
  - name: preamble_override
    label:
@@ -35,7 +35,7 @@ parameter_rules:
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 256
    default: 1024
    max: 4096
pricing:
  input: '1.0'
@@ -0,0 +1,45 @@
model: command-r-plus
label:
  en_US: command-r-plus
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    max: 4096
pricing:
  input: '3'
  output: '15'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,45 @@
model: command-r
label:
  en_US: command-r
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    max: 4096
pricing:
  input: '0.5'
  output: '1.5'
  unit: '0.000001'
  currency: USD
@@ -35,7 +35,7 @@ parameter_rules:
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 256
    default: 1024
    max: 4096
pricing:
  input: '1.0'
@@ -1,20 +1,38 @@
import json
import logging
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import Optional, Union, cast

import cohere
from cohere.responses import Chat, Generations
from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
from cohere.responses.generation import StreamingGenerations, StreamingText
from cohere import (
    ChatMessage,
    ChatStreamRequestToolResultsItem,
    GenerateStreamedResponse,
    GenerateStreamedResponse_StreamEnd,
    GenerateStreamedResponse_StreamError,
    GenerateStreamedResponse_TextGeneration,
    Generation,
    NonStreamedChatResponse,
    StreamedChatResponse,
    StreamedChatResponse_StreamEnd,
    StreamedChatResponse_TextGeneration,
    StreamedChatResponse_ToolCallsGeneration,
    Tool,
    ToolCall,
    ToolParameterDefinitionsValue,
)
from cohere.core import RequestOptions

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageRole,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
@@ -64,6 +82,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user
@@ -159,19 +178,26 @@ class CohereLargeLanguageModel(LargeLanguageModel):
        if stop:
            model_parameters['end_sequences'] = stop

        response = client.generate(
            prompt=prompt_messages[0].content,
            model=model,
            stream=stream,
            **model_parameters,
        )

        if stream:
            response = client.generate_stream(
                prompt=prompt_messages[0].content,
                model=model,
                **model_parameters,
                request_options=RequestOptions(max_retries=0)
            )

            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
        else:
            response = client.generate(
                prompt=prompt_messages[0].content,
                model=model,
                **model_parameters,
                request_options=RequestOptions(max_retries=0)
            )

            return self._handle_generate_response(model, credentials, response, prompt_messages)
        return self._handle_generate_response(model, credentials, response, prompt_messages)

    def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
    def _handle_generate_response(self, model: str, credentials: dict, response: Generation,
                                  prompt_messages: list[PromptMessage]) \
            -> LLMResult:
        """
@@ -191,8 +217,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
        )

        # calculate num tokens
        prompt_tokens = response.meta['billed_units']['input_tokens']
        completion_tokens = response.meta['billed_units']['output_tokens']
        prompt_tokens = int(response.meta.billed_units.input_tokens)
        completion_tokens = int(response.meta.billed_units.output_tokens)

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -207,7 +233,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):

        return response

    def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm stream response
@@ -220,8 +246,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
        index = 1
        full_assistant_content = ''
        for chunk in response:
            if isinstance(chunk, StreamingText):
                chunk = cast(StreamingText, chunk)
            if isinstance(chunk, GenerateStreamedResponse_TextGeneration):
                chunk = cast(GenerateStreamedResponse_TextGeneration, chunk)
                text = chunk.text

                if text is None:
@@ -244,10 +270,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                )

                index += 1
            elif chunk is None:
            elif isinstance(chunk, GenerateStreamedResponse_StreamEnd):
                chunk = cast(GenerateStreamedResponse_StreamEnd, chunk)

                # calculate num tokens
                prompt_tokens = response.meta['billed_units']['input_tokens']
                completion_tokens = response.meta['billed_units']['output_tokens']
                prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
                completion_tokens = self._num_tokens_from_messages(
                    model,
                    credentials,
                    [AssistantPromptMessage(content=full_assistant_content)]
                )

                # transform usage
                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -258,14 +290,18 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=AssistantPromptMessage(content=''),
                        finish_reason=response.finish_reason,
                        finish_reason=chunk.finish_reason,
                        usage=usage
                    )
                )
                break
            elif isinstance(chunk, GenerateStreamedResponse_StreamError):
                chunk = cast(GenerateStreamedResponse_StreamError, chunk)
                raise InvokeBadRequestError(chunk.err)

    def _chat_generate(self, model: str, credentials: dict,
                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke llm chat model
|
||||
@@ -274,6 +310,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
:param credentials: credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
@@ -282,32 +319,49 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'))
|
||||
|
||||
if user:
|
||||
model_parameters['user_name'] = user
|
||||
if stop:
|
||||
model_parameters['stop_sequences'] = stop
|
||||
|
||||
message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
|
||||
if tools:
|
||||
if len(tools) == 1:
|
||||
raise ValueError("Cohere tool call requires at least two tools to be specified.")
|
||||
|
||||
model_parameters['tools'] = self._convert_tools(tools)
|
||||
|
||||
message, chat_histories, tool_results \
|
||||
= self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
|
||||
|
||||
if tool_results:
|
||||
model_parameters['tool_results'] = tool_results
|
||||
|
||||
# chat model
|
||||
real_model = model
|
||||
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
|
||||
real_model = model.removesuffix('-chat')
|
||||
|
||||
response = client.chat(
|
||||
message=message,
|
||||
chat_history=chat_histories,
|
||||
model=real_model,
|
||||
stream=stream,
|
||||
return_preamble=True,
|
||||
**model_parameters,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
|
||||
response = client.chat_stream(
|
||||
message=message,
|
||||
chat_history=chat_histories,
|
||||
model=real_model,
|
||||
**model_parameters,
|
||||
request_options=RequestOptions(max_retries=0)
|
||||
)
|
||||
|
||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
|
||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
else:
|
||||
response = client.chat(
|
||||
message=message,
|
||||
chat_history=chat_histories,
|
||||
model=real_model,
|
||||
**model_parameters,
|
||||
request_options=RequestOptions(max_retries=0)
|
||||
)
|
||||
|
||||
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
|
||||
prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \
|
||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse,
|
||||
prompt_messages: list[PromptMessage]) \
|
||||
-> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
@@ -316,14 +370,27 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:param stop: stop words
|
||||
:return: llm response
|
||||
"""
|
||||
assistant_text = response.text
|
||||
|
||||
tool_calls = []
|
||||
if response.tool_calls:
|
||||
for cohere_tool_call in response.tool_calls:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=cohere_tool_call.name,
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=cohere_tool_call.name,
|
||||
arguments=json.dumps(cohere_tool_call.parameters)
|
||||
)
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_text
|
||||
content=assistant_text,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
|
||||
# calculate num tokens
|
||||
@@ -333,44 +400,38 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
if stop:
|
||||
# enforce stop tokens
|
||||
assistant_text = self.enforce_stop_tokens(assistant_text, stop)
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_text
|
||||
)
|
||||
|
||||
# transform response
|
||||
response = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
system_fingerprint=response.preamble
|
||||
usage=usage
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None) -> Generator:
|
||||
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
|
||||
response: Iterator[StreamedChatResponse],
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
"""
|
||||
Handle llm chat stream response
|
||||
|
||||
:param model: model name
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:param stop: stop words
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
|
||||
def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
|
||||
preamble: Optional[str] = None) -> LLMResultChunk:
|
||||
def final_response(full_text: str,
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall],
|
||||
index: int,
|
||||
finish_reason: Optional[str] = None) -> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
|
||||
|
||||
full_assistant_prompt_message = AssistantPromptMessage(
|
||||
content=full_text
|
||||
content=full_text,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
|
||||
|
||||
@@ -380,10 +441,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
return LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=preamble,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(content=''),
|
||||
message=AssistantPromptMessage(content='', tool_calls=tool_calls),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
@@ -391,9 +451,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
index = 1
|
||||
full_assistant_content = ''
|
||||
tool_calls = []
|
||||
for chunk in response:
|
||||
if isinstance(chunk, StreamTextGeneration):
|
||||
chunk = cast(StreamTextGeneration, chunk)
|
||||
if isinstance(chunk, StreamedChatResponse_TextGeneration):
|
||||
chunk = cast(StreamedChatResponse_TextGeneration, chunk)
|
||||
text = chunk.text
|
||||
|
||||
if text is None:
|
||||
@@ -404,12 +465,6 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
content=text
|
||||
)
|
||||
|
||||
# stop
|
||||
# notice: This logic can only cover few stop scenarios
|
||||
if stop and text in stop:
|
||||
yield final_response(full_assistant_content, index, 'stop')
|
||||
break
|
||||
|
||||
full_assistant_content += text
|
||||
|
||||
yield LLMResultChunk(
|
||||
@@ -422,39 +477,96 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
|
||||
index += 1
|
||||
elif isinstance(chunk, StreamEnd):
|
||||
chunk = cast(StreamEnd, chunk)
|
||||
yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
|
||||
elif isinstance(chunk, StreamedChatResponse_ToolCallsGeneration):
|
||||
chunk = cast(StreamedChatResponse_ToolCallsGeneration, chunk)
|
||||
if chunk.tool_calls:
|
||||
for cohere_tool_call in chunk.tool_calls:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=cohere_tool_call.name,
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=cohere_tool_call.name,
|
||||
arguments=json.dumps(cohere_tool_call.parameters)
|
||||
)
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
elif isinstance(chunk, StreamedChatResponse_StreamEnd):
|
||||
chunk = cast(StreamedChatResponse_StreamEnd, chunk)
|
||||
yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason)
|
||||
index += 1
|
||||
|
||||
def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
|
||||
-> tuple[str, list[dict]]:
|
||||
-> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]:
|
||||
"""
|
||||
Convert prompt messages to message and chat histories
|
||||
:param prompt_messages: prompt messages
|
||||
:return:
|
||||
"""
|
||||
chat_histories = []
|
||||
latest_tool_call_n_outputs = []
|
||||
for prompt_message in prompt_messages:
|
||||
chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))
|
||||
if prompt_message.role == PromptMessageRole.ASSISTANT:
|
||||
prompt_message = cast(AssistantPromptMessage, prompt_message)
|
||||
if prompt_message.tool_calls:
|
||||
for tool_call in prompt_message.tool_calls:
|
||||
latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem(
|
||||
call=ToolCall(
|
||||
name=tool_call.function.name,
|
||||
parameters=json.loads(tool_call.function.arguments)
|
||||
),
|
||||
outputs=[]
|
||||
))
|
||||
else:
|
||||
cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
|
||||
if cohere_prompt_message:
|
||||
chat_histories.append(cohere_prompt_message)
|
||||
elif prompt_message.role == PromptMessageRole.TOOL:
|
||||
prompt_message = cast(ToolPromptMessage, prompt_message)
|
||||
if latest_tool_call_n_outputs:
|
||||
i = 0
|
||||
for tool_call_n_outputs in latest_tool_call_n_outputs:
|
||||
if tool_call_n_outputs.call.name == prompt_message.tool_call_id:
|
||||
latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem(
|
||||
call=ToolCall(
|
||||
name=tool_call_n_outputs.call.name,
|
||||
parameters=tool_call_n_outputs.call.parameters
|
||||
),
|
||||
outputs=[{
|
||||
"result": prompt_message.content
|
||||
}]
|
||||
)
|
||||
break
|
||||
i += 1
|
||||
else:
|
||||
cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
|
||||
if cohere_prompt_message:
|
||||
chat_histories.append(cohere_prompt_message)
|
||||
|
||||
if latest_tool_call_n_outputs:
|
||||
new_latest_tool_call_n_outputs = []
|
||||
for tool_call_n_outputs in latest_tool_call_n_outputs:
|
||||
if tool_call_n_outputs.outputs:
|
||||
new_latest_tool_call_n_outputs.append(tool_call_n_outputs)
|
||||
|
||||
latest_tool_call_n_outputs = new_latest_tool_call_n_outputs
|
||||
|
||||
# get latest message from chat histories and pop it
|
||||
if len(chat_histories) > 0:
|
||||
latest_message = chat_histories.pop()
|
||||
message = latest_message['message']
|
||||
message = latest_message.message
|
||||
else:
|
||||
raise ValueError('Prompt messages is empty')
|
||||
|
||||
return message, chat_histories
|
||||
return message, chat_histories, latest_tool_call_n_outputs
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> Optional[ChatMessage]:
|
||||
"""
|
||||
Convert PromptMessage to dict for Cohere model
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "USER", "message": message.content}
|
||||
chat_message = ChatMessage(role="USER", message=message.content)
|
||||
else:
|
||||
sub_message_text = ''
|
||||
for message_content in message.content:
|
||||
@@ -462,20 +574,57 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_text += message_content.data
|
||||
|
||||
message_dict = {"role": "USER", "message": sub_message_text}
|
||||
chat_message = ChatMessage(role="USER", message=sub_message_text)
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "CHATBOT", "message": message.content}
|
||||
if not message.content:
|
||||
return None
|
||||
chat_message = ChatMessage(role="CHATBOT", message=message.content)
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "USER", "message": message.content}
|
||||
chat_message = ChatMessage(role="USER", message=message.content)
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
return None
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
if message.name:
|
||||
message_dict["user_name"] = message.name
|
||||
return chat_message
|
||||
|
||||
return message_dict
|
||||
def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]:
|
||||
"""
|
||||
Convert tools to Cohere model
|
||||
"""
|
||||
cohere_tools = []
|
||||
for tool in tools:
|
||||
properties = tool.parameters['properties']
|
||||
required_properties = tool.parameters['required']
|
||||
|
||||
parameter_definitions = {}
|
||||
for p_key, p_val in properties.items():
|
||||
required = False
|
||||
if property in required_properties:
|
||||
required = True
|
||||
|
||||
desc = p_val['description']
|
||||
if 'enum' in p_val:
|
||||
desc += (f"; Only accepts one of the following predefined options: "
|
||||
f"[{', '.join(p_val['enum'])}]")
|
||||
|
||||
parameter_definitions[p_key] = ToolParameterDefinitionsValue(
|
||||
description=desc,
|
||||
type=p_val['type'],
|
||||
required=required
|
||||
)
|
||||
|
||||
cohere_tool = Tool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameter_definitions=parameter_definitions
|
||||
)
|
||||
|
||||
cohere_tools.append(cohere_tool)
|
||||
|
||||
return cohere_tools
|
||||
|
||||
def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
|
||||
"""
|
||||
@@ -494,12 +643,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
model=model
|
||||
)
|
||||
|
||||
return response.length
|
||||
return len(response.tokens)
|
||||
|
||||
def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int:
|
||||
"""Calculate num tokens Cohere model."""
|
||||
messages = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
message_strs = [f"{message['role']}: {message['message']}" for message in messages]
|
||||
calc_messages = []
|
||||
for message in messages:
|
||||
cohere_message = self._convert_prompt_message_to_dict(message)
|
||||
if cohere_message:
|
||||
calc_messages.append(cohere_message)
|
||||
message_strs = [f"{message.role}: {message.message}" for message in calc_messages]
|
||||
message_str = "\n".join(message_strs)
|
||||
|
||||
real_model = model
|
||||
@@ -565,13 +718,21 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
cohere.CohereConnectionError
|
||||
cohere.errors.service_unavailable_error.ServiceUnavailableError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
cohere.errors.internal_server_error.InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
cohere.errors.too_many_requests_error.TooManyRequestsError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
cohere.errors.unauthorized_error.UnauthorizedError,
|
||||
cohere.errors.forbidden_error.ForbiddenError
|
||||
],
|
||||
InvokeServerUnavailableError: [],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [],
|
||||
InvokeBadRequestError: [
|
||||
cohere.CohereAPIError,
|
||||
cohere.CohereError,
|
||||
cohere.core.api_error.ApiError,
|
||||
cohere.errors.bad_request_error.BadRequestError,
|
||||
cohere.errors.not_found_error.NotFoundError,
|
||||
]
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
 from typing import Optional

 import cohere
+from cohere.core import RequestOptions

 from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
 from core.model_runtime.errors.invoke import (
@@ -44,19 +45,21 @@ class CohereRerankModel(RerankModel):

         # initialize client
         client = cohere.Client(credentials.get('api_key'))
-        results = client.rerank(
+        response = client.rerank(
             query=query,
             documents=docs,
             model=model,
-            top_n=top_n
+            top_n=top_n,
+            return_documents=True,
+            request_options=RequestOptions(max_retries=0)
         )

         rerank_documents = []
-        for idx, result in enumerate(results):
+        for idx, result in enumerate(response.results):
             # format document
             rerank_document = RerankDocument(
                 index=result.index,
-                text=result.document['text'],
+                text=result.document.text,
                 score=result.relevance_score,
             )
@@ -108,13 +111,21 @@ class CohereRerankModel(RerankModel):
         """
         return {
             InvokeConnectionError: [
-                cohere.CohereConnectionError,
+                cohere.errors.service_unavailable_error.ServiceUnavailableError
+            ],
+            InvokeServerUnavailableError: [
+                cohere.errors.internal_server_error.InternalServerError
+            ],
+            InvokeRateLimitError: [
+                cohere.errors.too_many_requests_error.TooManyRequestsError
+            ],
+            InvokeAuthorizationError: [
+                cohere.errors.unauthorized_error.UnauthorizedError,
+                cohere.errors.forbidden_error.ForbiddenError
             ],
-            InvokeServerUnavailableError: [],
-            InvokeRateLimitError: [],
-            InvokeAuthorizationError: [],
             InvokeBadRequestError: [
-                cohere.CohereAPIError,
-                cohere.CohereError,
+                cohere.core.api_error.ApiError,
+                cohere.errors.bad_request_error.BadRequestError,
+                cohere.errors.not_found_error.NotFoundError,
             ]
         }

@@ -3,7 +3,7 @@ from typing import Optional

 import cohere
 import numpy as np
-from cohere.responses import Tokens
+from cohere.core import RequestOptions

 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@@ -52,8 +52,8 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
                 text=text
             )

-            for j in range(0, tokenize_response.length, context_size):
-                tokens += [tokenize_response.token_strings[j: j + context_size]]
+            for j in range(0, len(tokenize_response), context_size):
+                tokens += [tokenize_response[j: j + context_size]]
                 indices += [i]

         batched_embeddings = []
@@ -127,9 +127,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         except Exception as e:
             raise self._transform_invoke_error(e)

-        return response.length
+        return len(response)

-    def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens:
+    def _tokenize(self, model: str, credentials: dict, text: str) -> list[str]:
         """
         Tokenize text
         :param model: model name
@@ -138,17 +138,19 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         :return:
         """
         if not text:
-            return Tokens([], [], {})
+            return []

         # initialize client
         client = cohere.Client(credentials.get('api_key'))

         response = client.tokenize(
             text=text,
-            model=model
+            model=model,
+            offline=False,
+            request_options=RequestOptions(max_retries=0)
         )

-        return response
+        return response.token_strings

     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -184,10 +186,11 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         response = client.embed(
             texts=texts,
             model=model,
-            input_type='search_document' if len(texts) > 1 else 'search_query'
+            input_type='search_document' if len(texts) > 1 else 'search_query',
+            request_options=RequestOptions(max_retries=1)
         )

-        return response.embeddings, response.meta['billed_units']['input_tokens']
+        return response.embeddings, int(response.meta.billed_units.input_tokens)

     def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
         """
@@ -231,13 +234,21 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         """
         return {
             InvokeConnectionError: [
-                cohere.CohereConnectionError
+                cohere.errors.service_unavailable_error.ServiceUnavailableError
+            ],
+            InvokeServerUnavailableError: [
+                cohere.errors.internal_server_error.InternalServerError
+            ],
+            InvokeRateLimitError: [
+                cohere.errors.too_many_requests_error.TooManyRequestsError
+            ],
+            InvokeAuthorizationError: [
+                cohere.errors.unauthorized_error.UnauthorizedError,
+                cohere.errors.forbidden_error.ForbiddenError
            ],
-            InvokeServerUnavailableError: [],
-            InvokeRateLimitError: [],
-            InvokeAuthorizationError: [],
             InvokeBadRequestError: [
-                cohere.CohereAPIError,
-                cohere.CohereError,
+                cohere.core.api_error.ApiError,
+                cohere.errors.bad_request_error.BadRequestError,
+                cohere.errors.not_found_error.NotFoundError,
             ]
         }

@@ -0,0 +1,37 @@
+model: gemini-1.5-pro-latest
+label:
+  en_US: Gemini 1.5 Pro
+model_type: llm
+features:
+  - agent-thought
+  - vision
+model_properties:
+  mode: chat
+  context_size: 1048576
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_tokens_to_sample
+    use_template: max_tokens
+    required: true
+    default: 8192
+    min: 1
+    max: 8192
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD

@@ -1,8 +1,31 @@
+import json
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast

-from core.model_runtime.entities.llm_entities import LLMResult
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+import requests
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+)
 from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


@@ -13,6 +36,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         self._add_custom_parameters(credentials)
+        self._add_function_call(model, credentials)
         user = user[:32] if user else None
         return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

@@ -20,7 +44,286 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         self._add_custom_parameters(credentials)
         super().validate_credentials(model, credentials)

-    @staticmethod
-    def _add_custom_parameters(credentials: dict) -> None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        return AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model, zh_Hans=model),
+            model_type=ModelType.LLM,
+            features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
+            if credentials.get('function_calling_type') == 'tool_call'
+            else [],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name='temperature',
+                    use_template='temperature',
+                    label=I18nObject(en_US='Temperature', zh_Hans='温度'),
+                    type=ParameterType.FLOAT,
+                ),
+                ParameterRule(
+                    name='max_tokens',
+                    use_template='max_tokens',
+                    default=512,
+                    min=1,
+                    max=int(credentials.get('max_tokens', 4096)),
+                    label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
+                    type=ParameterType.INT,
+                ),
+                ParameterRule(
+                    name='top_p',
+                    use_template='top_p',
+                    label=I18nObject(en_US='Top P', zh_Hans='Top P'),
+                    type=ParameterType.FLOAT,
+                ),
+            ]
+        )
+
+    def _add_custom_parameters(self, credentials: dict) -> None:
         credentials['mode'] = 'chat'
         credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
+
+    def _add_function_call(self, model: str, credentials: dict) -> None:
+        model_schema = self.get_model_schema(model, credentials)
+        if model_schema and set([
+            ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
+        ]).intersection(model_schema.features or []):
+            credentials['function_calling_type'] = 'tool_call'
+
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict for OpenAI API format
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(PromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": message_content.data,
+                                "detail": message_content.detail.value
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+            if message.tool_calls:
+                message_dict["tool_calls"] = []
+                for function_call in message.tool_calls:
+                    message_dict["tool_calls"].append({
+                        "id": function_call.id,
+                        "type": function_call.type,
+                        "function": {
+                            "name": function_call.function.name,
+                            "arguments": function_call.function.arguments
+                        }
+                    })
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        if message.name:
+            message_dict["name"] = message.name
+
+        return message_dict
+
+    def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
+        """
+        Extract tool calls from response
+
+        :param response_tool_calls: response tool calls
+        :return: list of tool calls
+        """
+        tool_calls = []
+        if response_tool_calls:
+            for response_tool_call in response_tool_calls:
+                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                    name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
+                    arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
+                )
+
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=response_tool_call["id"] if response_tool_call.get("id") else "",
+                    type=response_tool_call["type"] if response_tool_call.get("type") else "",
+                    function=function
+                )
+                tool_calls.append(tool_call)
+
+        return tool_calls
+
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm stream response
+
+        :param model: model name
+        :param credentials: model credentials
+        :param response: streamed response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator
+        """
+        full_assistant_content = ''
+        chunk_index = 0
+
+        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
+                -> LLMResultChunk:
+            # calculate num tokens
+            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+            completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
+
+            # transform usage
+            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+            return LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=index,
+                    message=message,
+                    finish_reason=finish_reason,
+                    usage=usage
+                )
+            )
+
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+        finish_reason = "Unknown"
+
+        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
+            def get_tool_call(tool_name: str):
+                if not tool_name:
+                    return tools_calls[-1]
+
+                tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
+                if tool_call is None:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id='',
+                        type='',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
+                    )
+                    tools_calls.append(tool_call)
+
+                return tool_call
+
+            for new_tool_call in new_tool_calls:
+                # get tool call
+                tool_call = get_tool_call(new_tool_call.function.name)
+                # update tool call
+                if new_tool_call.id:
+                    tool_call.id = new_tool_call.id
+                if new_tool_call.type:
+                    tool_call.type = new_tool_call.type
+                if new_tool_call.function.name:
+                    tool_call.function.name = new_tool_call.function.name
+                if new_tool_call.function.arguments:
+                    tool_call.function.arguments += new_tool_call.function.arguments
+
+        for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
+            if chunk:
+                # ignore sse comments
+                if chunk.startswith(':'):
+                    continue
+                decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
+                chunk_json = None
+                try:
+                    chunk_json = json.loads(decoded_chunk)
+                # stream ended
+                except json.JSONDecodeError as e:
+                    yield create_final_llm_result_chunk(
+                        index=chunk_index + 1,
+                        message=AssistantPromptMessage(content=""),
+                        finish_reason="Non-JSON encountered."
+                    )
+                    break
+                if not chunk_json or len(chunk_json['choices']) == 0:
+                    continue
+
+                choice = chunk_json['choices'][0]
+                finish_reason = chunk_json['choices'][0].get('finish_reason')
+                chunk_index += 1
+
+                if 'delta' in choice:
+                    delta = choice['delta']
+                    delta_content = delta.get('content')
+
+                    assistant_message_tool_calls = delta.get('tool_calls', None)
+                    # assistant_message_function_call = delta.delta.function_call
+
+                    # extract tool calls from response
+                    if assistant_message_tool_calls:
+                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                        increase_tool_call(tool_calls)
+
+                    if delta_content is None or delta_content == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=delta_content,
+                        tool_calls=tool_calls if assistant_message_tool_calls else []
+                    )
+
+                    full_assistant_content += delta_content
+                elif 'text' in choice:
+                    choice_text = choice.get('text', '')
+                    if choice_text == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
+                    full_assistant_content += choice_text
+                else:
+                    continue
+
+                # check payload indicator for completion
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=chunk_index,
+                        message=assistant_prompt_message,
+                    )
+                )
+
+            chunk_index += 1
+
+        if tools_calls:
+            yield LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=chunk_index,
+                    message=AssistantPromptMessage(
+                        tool_calls=tools_calls,
+                        content=""
+                    ),
+                )
+            )
+
+        yield create_final_llm_result_chunk(
+            index=chunk_index,
+            message=AssistantPromptMessage(content=""),
+            finish_reason=finish_reason
+        )

@@ -20,6 +20,7 @@ supported_model_types:
   - llm
 configurate_methods:
   - predefined-model
+  - customizable-model
 provider_credential_schema:
   credential_form_schemas:
     - variable: api_key
@@ -30,3 +31,51 @@ provider_credential_schema:
       placeholder:
         zh_Hans: 在此输入您的 API Key
         en_US: Enter your API Key
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: context_size
+      label:
+        zh_Hans: 模型上下文长度
+        en_US: Model context size
+      required: true
+      type: text-input
+      default: '4096'
+      placeholder:
+        zh_Hans: 在此输入您的模型上下文长度
+        en_US: Enter your Model context size
+    - variable: max_tokens
+      label:
+        zh_Hans: 最大 token 上限
+        en_US: Upper bound for max tokens
+      default: '4096'
+      type: text-input
+    - variable: function_calling_type
+      label:
+        en_US: Function calling
+      type: select
+      required: false
+      default: no_call
+      options:
+        - value: no_call
+          label:
+            en_US: Not supported
+            zh_Hans: 不支持
+        - value: tool_call
+          label:
+            en_US: Tool Call
+            zh_Hans: Tool Call

@@ -1,4 +1,6 @@
 - gpt-4
+- gpt-4-turbo
+- gpt-4-turbo-2024-04-09
 - gpt-4-turbo-preview
 - gpt-4-32k
 - gpt-4-1106-preview

@@ -0,0 +1,57 @@
+model: gpt-4-turbo-2024-04-09
+label:
+  zh_Hans: gpt-4-turbo-2024-04-09
+  en_US: gpt-4-turbo-2024-04-09
+model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
+  - vision
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 512
+    min: 1
+    max: 4096
+  - name: seed
+    label:
+      zh_Hans: 种子
+      en_US: Seed
+    type: int
+    help:
+      zh_Hans: 如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint
+        响应参数来监视变化。
+      en_US: If specified, model will make a best effort to sample deterministically,
+        such that repeated requests with the same seed and parameters should return
+        the same result. Determinism is not guaranteed, and you should refer to the
+        system_fingerprint response parameter to monitor changes in the backend.
+    required: false
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '0.01'
+  output: '0.03'
+  unit: '0.001'
+  currency: USD

@@ -0,0 +1,57 @@
+model: gpt-4-turbo
+label:
+  zh_Hans: gpt-4-turbo
+  en_US: gpt-4-turbo
+model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
+  - vision
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 512
+    min: 1
+    max: 4096
+  - name: seed
+    label:
+      zh_Hans: 种子
+      en_US: Seed
+    type: int
+    help:
+      zh_Hans: 如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint
+        响应参数来监视变化。
+      en_US: If specified, model will make a best effort to sample deterministically,
+        such that repeated requests with the same seed and parameters should return
+        the same result. Determinism is not guaranteed, and you should refer to the
+        system_fingerprint response parameter to monitor changes in the backend.
+    required: false
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '0.01'
+  output: '0.03'
+  unit: '0.001'
+  currency: USD

@@ -547,6 +547,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs['user'] = user

+        # clear illegal prompt messages
+        prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
+
         # chat model
         response = client.chat.completions.create(
             messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
@@ -757,6 +760,31 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):

         return tool_call

+    def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Clear illegal prompt messages for OpenAI API
+
+        :param model: model name
+        :param prompt_messages: prompt messages
+        :return: cleaned prompt messages
+        """
+        checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09']
+
+        if model in checklist:
+            # count how many user messages are there
+            user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
+            if user_message_count > 1:
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, UserPromptMessage):
+                        if isinstance(prompt_message.content, list):
+                            prompt_message.content = '\n'.join([
+                                item.data if item.type == PromptMessageContentType.TEXT else
+                                '[IMAGE]' if item.type == PromptMessageContentType.IMAGE else ''
+                                for item in prompt_message.content
+                            ])
+
+        return prompt_messages
+
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict for OpenAI API

@@ -167,23 +167,27 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
support_function_call = False
|
||||
features = []
|
||||
|
||||
function_calling_type = credentials.get('function_calling_type', 'no_call')
|
||||
if function_calling_type == 'function_call':
|
||||
features = [ModelFeature.TOOL_CALL]
|
||||
support_function_call = True
|
||||
features.append(ModelFeature.TOOL_CALL)
|
||||
endpoint_url = credentials["endpoint_url"]
|
||||
# if not endpoint_url.endswith('/'):
|
||||
# endpoint_url += '/'
|
||||
# if 'https://api.openai.com/v1/' == endpoint_url:
|
||||
# features = [ModelFeature.STREAM_TOOL_CALL]
|
||||
# features.append(ModelFeature.STREAM_TOOL_CALL)
|
||||
|
||||
vision_support = credentials.get('vision_support', 'not_support')
|
||||
if vision_support == 'support':
|
||||
features.append(ModelFeature.VISION)
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
features=features if support_function_call else [],
|
||||
features=features,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
|
||||
ModelPropertyKey.MODE: credentials.get('mode'),
|
||||
@@ -378,13 +382,41 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
delimiter = credentials.get("stream_mode_delimiter", "\n\n")
|
||||
delimiter = codecs.decode(delimiter, "unicode_escape")
|
||||
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
|
||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
||||
def get_tool_call(tool_call_id: str):
|
||||
tool_call = next(
|
||||
(tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None
|
||||
)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id='',
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name='',
|
||||
arguments=''
|
||||
)
|
||||
)
|
||||
tools_calls.append(tool_call)
|
||||
return tool_call
|
||||
|
||||
for new_tool_call in new_tool_calls:
|
||||
# get tool call
|
||||
tool_call = get_tool_call(new_tool_call.id)
|
||||
# update tool call
|
||||
tool_call.id = new_tool_call.id
|
||||
tool_call.type = new_tool_call.type
|
||||
tool_call.function.name = new_tool_call.function.name
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
||||
if chunk:
|
||||
# ignore sse comments
|
||||
if chunk.startswith(':'):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
|
||||
chunk_json = None
|
||||
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
# stream ended
|
||||
@@ -405,8 +437,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
if 'delta' in choice:
|
||||
delta = choice['delta']
|
||||
delta_content = delta.get('content')
|
||||
if delta_content is None or delta_content == '':
|
||||
continue
|
||||
|
||||
assistant_message_tool_calls = delta.get('tool_calls', None)
|
||||
# assistant_message_function_call = delta.delta.function_call
|
||||
@@ -414,6 +444,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
# extract tool calls from response
|
||||
if assistant_message_tool_calls:
|
||||
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
increase_tool_call(tool_calls)
|
||||
|
||||
if delta_content is None or delta_content == '':
|
||||
continue
|
||||
|
||||
# function_call = self._extract_response_function_call(assistant_message_function_call)
|
||||
# tool_calls = [function_call] if function_call else []
|
||||
|
||||
@@ -437,6 +472,18 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
|
||||
# check payload indicator for completion
|
||||
if finish_reason is not None:
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(
|
||||
tool_calls=tools_calls,
|
||||
),
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
)
|
||||
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index,
|
||||
message=assistant_prompt_message,
|
||||
@@ -573,7 +620,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_string(self, model: str, text: str,
|
||||
def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessageContent]],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Approximate num tokens for model with gpt2 tokenizer.
|
||||
@@ -583,7 +630,16 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
if isinstance(text, str):
|
||||
full_text = text
|
||||
else:
|
||||
full_text = ''
|
||||
for message_content in text:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(PromptMessageContent, message_content)
|
||||
full_text += message_content.data
|
||||
|
||||
num_tokens = self._get_num_tokens_by_gpt2(full_text)
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(tools)
|
||||
@@ -735,4 +791,4 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
function=function
|
||||
)
|
||||
|
||||
return tool_call
|
||||
return tool_call
|
||||
@@ -97,6 +97,25 @@ model_credential_schema:
|
||||
label:
|
||||
en_US: Not Support
|
||||
zh_Hans: 不支持
|
||||
- variable: vision_support
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
zh_Hans: Vision 支持
|
||||
en_US: Vision Support
|
||||
type: select
|
||||
required: false
|
||||
default: no_support
|
||||
options:
|
||||
- value: support
|
||||
label:
|
||||
en_US: Support
|
||||
zh_Hans: 支持
|
||||
- value: no_support
|
||||
label:
|
||||
en_US: Not Support
|
||||
zh_Hans: 不支持
|
||||
- variable: stream_mode_delimiter
|
||||
label:
|
||||
zh_Hans: 流模式返回结果的分隔符
|
||||
|
@@ -1,6 +1,6 @@
-model: ernie-3.5-8k
+model: ernie-3.5-4k-0205
 label:
-  en_US: Ernie-3.5-8K
+  en_US: Ernie-3.5-4k-0205
 model_type: llm
 features:
   - agent-thought

@@ -232,8 +232,8 @@ class SimplePromptTransform(PromptTransform):
                 )
             ),
             max_token_limit=rest_tokens,
-            ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
-            human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+            human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
+            ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
         )

         # get prompt

||||
@@ -24,56 +24,67 @@ class Jieba(BaseKeyword):
|
||||
self._config = KeywordTableConfig()
|
||||
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
return self
|
||||
return self
|
||||
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keywords_list = kwargs.get('keywords_list', None)
|
||||
for i in range(len(texts)):
|
||||
text = texts[i]
|
||||
if keywords_list:
|
||||
keywords = keywords_list[i]
|
||||
else:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keywords_list = kwargs.get('keywords_list', None)
|
||||
for i in range(len(texts)):
|
||||
text = texts[i]
|
||||
if keywords_list:
|
||||
keywords = keywords_list[i]
|
||||
if not keywords:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content,
|
||||
self._config.max_keywords_per_chunk)
|
||||
else:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
return id in set.union(*keyword_table.values())
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
|
||||
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
     def delete_by_document_id(self, document_id: str):
-        # get segment ids by document_id
-        segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.dataset_id == self.dataset.id,
-            DocumentSegment.document_id == document_id
-        ).all()
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            # get segment ids by document_id
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.dataset_id == self.dataset.id,
+                DocumentSegment.document_id == document_id
+            ).all()

-        ids = [segment.index_node_id for segment in segments]
+            ids = [segment.index_node_id for segment in segments]

-        keyword_table = self._get_dataset_keyword_table()
-        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+            keyword_table = self._get_dataset_keyword_table()
+            keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

-        self._save_dataset_keyword_table(keyword_table)
+            self._save_dataset_keyword_table(keyword_table)

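The hunks above all apply the same fix: every read-modify-write of the shared per-dataset keyword table is now serialized behind one Redis lock. A minimal sketch of the pattern, assuming a plain `redis.Redis` client (dify wires its own `redis_client` through its extensions) and a dict standing in for the stored table.

```python
import redis

redis_client = redis.Redis()  # assumption: a local Redis, for illustration only
_store: dict[str, dict] = {}  # stand-in for the persisted DatasetKeywordTable

def update_keyword_table(dataset_id: str, doc_id: str, keywords: list[str]) -> None:
    lock_name = 'keyword_indexing_lock_{}'.format(dataset_id)
    # timeout=600 mirrors the hunks above: the lock auto-expires if a worker
    # dies mid-update, so later writers are not blocked forever.
    with redis_client.lock(lock_name, timeout=600):
        table = _store.get(dataset_id, {})                 # read
        for keyword in keywords:
            table.setdefault(keyword, set()).add(doc_id)   # modify
        _store[dataset_id] = table                         # write back
```

Without the lock, two concurrent indexers can both read the same table, each add their own keywords, and the second save silently drops the first writer's update.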
     def search(
             self, query: str,
@@ -106,13 +117,15 @@ class Jieba(BaseKeyword):
         return documents

     def delete(self) -> None:
-        dataset_keyword_table = self.dataset.dataset_keyword_table
-        if dataset_keyword_table:
-            db.session.delete(dataset_keyword_table)
-            db.session.commit()
-            if dataset_keyword_table.data_source_type != 'database':
-                file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
-                storage.delete(file_key)
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            dataset_keyword_table = self.dataset.dataset_keyword_table
+            if dataset_keyword_table:
+                db.session.delete(dataset_keyword_table)
+                db.session.commit()
+                if dataset_keyword_table.data_source_type != 'database':
+                    file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
+                    storage.delete(file_key)

     def _save_dataset_keyword_table(self, keyword_table):
         keyword_table_dict = {
@@ -135,33 +148,31 @@ class Jieba(BaseKeyword):
             storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8'))

     def _get_dataset_keyword_table(self) -> Optional[dict]:
-        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
-        with redis_client.lock(lock_name, timeout=20):
-            dataset_keyword_table = self.dataset.dataset_keyword_table
-            if dataset_keyword_table:
-                keyword_table_dict = dataset_keyword_table.keyword_table_dict
-                if keyword_table_dict:
-                    return keyword_table_dict['__data__']['table']
-            else:
-                keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
-                dataset_keyword_table = DatasetKeywordTable(
-                    dataset_id=self.dataset.id,
-                    keyword_table='',
-                    data_source_type=keyword_data_source_type,
-                )
-                if keyword_data_source_type == 'database':
-                    dataset_keyword_table.keyword_table = json.dumps({
-                        '__type__': 'keyword_table',
-                        '__data__': {
-                            "index_id": self.dataset.id,
-                            "summary": None,
-                            "table": {}
-                        }
-                    }, cls=SetEncoder)
-                db.session.add(dataset_keyword_table)
-                db.session.commit()
+        dataset_keyword_table = self.dataset.dataset_keyword_table
+        if dataset_keyword_table:
+            keyword_table_dict = dataset_keyword_table.keyword_table_dict
+            if keyword_table_dict:
+                return keyword_table_dict['__data__']['table']
+        else:
+            keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
+            dataset_keyword_table = DatasetKeywordTable(
+                dataset_id=self.dataset.id,
+                keyword_table='',
+                data_source_type=keyword_data_source_type,
+            )
+            if keyword_data_source_type == 'database':
+                dataset_keyword_table.keyword_table = json.dumps({
+                    '__type__': 'keyword_table',
+                    '__data__': {
+                        "index_id": self.dataset.id,
+                        "summary": None,
+                        "table": {}
+                    }
+                }, cls=SetEncoder)
+            db.session.add(dataset_keyword_table)
+            db.session.commit()

-            return {}
+        return {}

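For reference, `_get_dataset_keyword_table` returns only the inner `['__data__']['table']` mapping of the JSON envelope built above. A small round-trip sketch; `SetEncoder` is referenced but not shown in this diff, so the encoder below is an assumed minimal equivalent that serializes sets as lists.

```python
import json

class SetEncoder(json.JSONEncoder):
    # Assumption: dify's real SetEncoder is defined elsewhere; any encoder
    # that turns the keyword -> set-of-doc-ids values into lists will do.
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        return super().default(obj)

envelope = {
    '__type__': 'keyword_table',
    '__data__': {
        'index_id': 'dataset-id',
        'summary': None,
        'table': {'dify': {'doc-1', 'doc-2'}},
    },
}
raw = json.dumps(envelope, cls=SetEncoder)
# _get_dataset_keyword_table returns only the inner mapping:
table = json.loads(raw)['__data__']['table']
print(table)  # {'dify': ['doc-1', 'doc-2']}
```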
     def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
         for keyword in keywords:

@@ -20,16 +20,17 @@ class MilvusConfig(BaseModel):
     password: str
     secure: bool = False
     batch_size: int = 100
+    database: str = "default"

     @root_validator()
     def validate_config(cls, values: dict) -> dict:
-        if not values['host']:
+        if not values.get('host'):
             raise ValueError("config MILVUS_HOST is required")
-        if not values['port']:
+        if not values.get('port'):
             raise ValueError("config MILVUS_PORT is required")
-        if not values['user']:
+        if not values.get('user'):
             raise ValueError("config MILVUS_USER is required")
-        if not values['password']:
+        if not values.get('password'):
             raise ValueError("config MILVUS_PASSWORD is required")
         return values

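Why `values.get(...)` instead of `values[...]` in the validator above: in pydantic v1 (which this code uses), a field that failed its own validation or was never supplied can be absent from `values`, so indexing raises a bare KeyError instead of the intended ValueError. A trimmed, runnable stand-in:

```python
from pydantic import BaseModel, root_validator

class MilvusConfigSketch(BaseModel):  # trimmed stand-in for MilvusConfig
    host: str = None
    port: int = None

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values.get('host'):      # safe even when 'host' is absent
            raise ValueError("config MILVUS_HOST is required")
        if not values.get('port'):
            raise ValueError("config MILVUS_PORT is required")
        return values

try:
    MilvusConfigSketch()
except ValueError as e:          # pydantic's ValidationError subclasses ValueError
    print(e)                     # reports: config MILVUS_HOST is required
```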
@@ -39,7 +40,8 @@ class MilvusConfig(BaseModel):
             'port': self.port,
             'user': self.user,
             'password': self.password,
-            'secure': self.secure
+            'secure': self.secure,
+            'db_name': self.database,
         }

@@ -128,7 +130,8 @@ class MilvusVector(BaseVector):
             uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
         else:
             uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
-        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
+        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
+                            db_name=self._client_config.database)

         from pymilvus import utility
         if utility.has_collection(self._collection_name, using=alias):
@@ -140,7 +143,8 @@ class MilvusVector(BaseVector):
             uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
         else:
             uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
-        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
+        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
+                            db_name=self._client_config.database)

         from pymilvus import utility
         if not utility.has_collection(self._collection_name, using=alias):
@@ -192,7 +196,7 @@ class MilvusVector(BaseVector):
         else:
             uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
         connections.connect(alias=alias, uri=uri, user=self._client_config.user,
-                            password=self._client_config.password)
+                            password=self._client_config.password, db_name=self._client_config.database)
         if not utility.has_collection(self._collection_name, using=alias):
             from pymilvus import CollectionSchema, DataType, FieldSchema
             from pymilvus.orm.types import infer_dtype_bydata

@@ -110,6 +110,7 @@ class Vector:
                 user=config.get('MILVUS_USER'),
                 password=config.get('MILVUS_PASSWORD'),
                 secure=config.get('MILVUS_SECURE'),
+                database=config.get('MILVUS_DATABASE'),
             )
         )
     else:

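The Milvus hunks above thread one new setting, `MILVUS_DATABASE`, from the vector factory down to `connections.connect(..., db_name=...)`, so collections are looked up inside a named Milvus database rather than the server default. End-to-end sketch against pymilvus; host, credentials, and names below are placeholders.

```python
from pymilvus import connections, utility

connections.connect(
    alias='my_alias',
    uri='http://127.0.0.1:19530',
    user='root',
    password='Milvus',
    db_name='dify',   # the new MILVUS_DATABASE setting ends up here
)
print(utility.has_collection('my_collection', using='my_alias'))
```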
@@ -34,6 +34,7 @@ class CSVExtractor(BaseExtractor):

     def extract(self) -> list[Document]:
         """Load data into document objects."""
+        docs = []
         try:
             with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
                 docs = self._read_from_file(csvfile)

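The one-line `docs = []` above matters because `extract` can otherwise leave `docs` unbound when `open()` raises and an exception branch falls through to the return. A minimal illustration of the failure mode, assuming the surrounding try/except ends by returning `docs`:

```python
def extract(path: str) -> list:
    docs = []  # pre-initialise so every code path can safely return
    try:
        with open(path, newline="") as f:
            docs = [line.strip() for line in f]
    except OSError:
        pass  # stand-in for the extractor's encoding-fallback handling
    return docs

print(extract("/nonexistent.csv"))  # [] instead of UnboundLocalError
```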
@@ -1,59 +0,0 @@
-import time
-from collections.abc import Mapping
-from typing import Any, Optional
-
-from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.chat_models.base import SimpleChatModel
-from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
-
-
-class FakeLLM(SimpleChatModel):
-    """Fake ChatModel for testing purposes."""
-
-    streaming: bool = False
-    """Whether to stream the results or not."""
-    response: str
-
-    @property
-    def _llm_type(self) -> str:
-        return "fake-chat-model"
-
-    def _call(
-            self,
-            messages: list[BaseMessage],
-            stop: Optional[list[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> str:
-        """First try to lookup in queries, else return 'foo' or 'bar'."""
-        return self.response
-
-    @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        return {"response": self.response}
-
-    def get_num_tokens(self, text: str) -> int:
-        return 0
-
-    def _generate(
-            self,
-            messages: list[BaseMessage],
-            stop: Optional[list[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> ChatResult:
-        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
-        if self.streaming:
-            for token in output_str:
-                if run_manager:
-                    run_manager.on_llm_new_token(token)
-                    time.sleep(0.01)
-
-        message = AIMessage(content=output_str)
-        generation = ChatGeneration(message=message)
-        llm_output = {"token_usage": {
-            'prompt_tokens': 0,
-            'completion_tokens': 0,
-            'total_tokens': 0,
-        }}
-        return ChatResult(generations=[generation], llm_output=llm_output)
@@ -1,46 +0,0 @@
-from typing import Any, Optional
-
-from langchain import LLMChain as LCLLMChain
-from langchain.callbacks.manager import CallbackManagerForChainRun
-from langchain.schema import Generation, LLMResult
-from langchain.schema.language_model import BaseLanguageModel
-
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.message_entities import lc_messages_to_prompt_messages
-from core.model_manager import ModelInstance
-from core.rag.retrieval.agent.fake_llm import FakeLLM
-
-
-class LLMChain(LCLLMChain):
-    model_config: ModelConfigWithCredentialsEntity
-    """The language model instance to use."""
-    llm: BaseLanguageModel = FakeLLM(response="")
-    parameters: dict[str, Any] = {}
-
-    def generate(
-            self,
-            input_list: list[dict[str, Any]],
-            run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> LLMResult:
-        """Generate LLM result from inputs."""
-        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
-        messages = prompts[0].to_messages()
-        prompt_messages = lc_messages_to_prompt_messages(messages)
-
-        model_instance = ModelInstance(
-            provider_model_bundle=self.model_config.provider_model_bundle,
-            model=self.model_config.model,
-        )
-
-        result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            stream=False,
-            stop=stop,
-            model_parameters=self.parameters
-        )
-
-        generations = [
-            [Generation(text=result.message.content)]
-        ]
-
-        return LLMResult(generations=generations)
@@ -1,179 +0,0 @@
-from collections.abc import Sequence
-from typing import Any, Optional, Union
-
-from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
-from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage
-from langchain.tools import BaseTool
-from pydantic import root_validator
-
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.message_entities import lc_messages_to_prompt_messages
-from core.model_manager import ModelInstance
-from core.model_runtime.entities.message_entities import PromptMessageTool
-from core.rag.retrieval.agent.fake_llm import FakeLLM
-
-
-class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
-    """
-    An Multi Dataset Retrieve Agent driven by Router.
-    """
-    model_config: ModelConfigWithCredentialsEntity
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    @root_validator
-    def validate_llm(cls, values: dict) -> dict:
-        return values
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-
-        :param query:
-        :return:
-        """
-        return True
-
-    def plan(
-            self,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        if len(self.tools) == 0:
-            return AgentFinish(return_values={"output": ''}, log='')
-        elif len(self.tools) == 1:
-            tool = next(iter(self.tools))
-            rst = tool.run(tool_input={'query': kwargs['input']})
-            # output = ''
-            # rst_json = json.loads(rst)
-            # for item in rst_json:
-            #     output += f'{item["content"]}\n'
-            return AgentFinish(return_values={"output": rst}, log=rst)
-
-        if intermediate_steps:
-            _, observation = intermediate_steps[-1]
-            return AgentFinish(return_values={"output": observation}, log=observation)
-
-        try:
-            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
-            if isinstance(agent_decision, AgentAction):
-                tool_inputs = agent_decision.tool_input
-                if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
-                    tool_inputs['query'] = kwargs['input']
-                    agent_decision.tool_input = tool_inputs
-            else:
-                agent_decision.return_values['output'] = ''
-            return agent_decision
-        except Exception as e:
-            raise e
-
-    def real_plan(
-            self,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
-        selected_inputs = {
-            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
-        }
-        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
-        prompt = self.prompt.format_prompt(**full_inputs)
-        messages = prompt.to_messages()
-        prompt_messages = lc_messages_to_prompt_messages(messages)
-
-        model_instance = ModelInstance(
-            provider_model_bundle=self.model_config.provider_model_bundle,
-            model=self.model_config.model,
-        )
-
-        tools = []
-        for function in self.functions:
-            tool = PromptMessageTool(
-                **function
-            )
-
-            tools.append(tool)
-
-        result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            tools=tools,
-            stream=False,
-            model_parameters={
-                'temperature': 0.2,
-                'top_p': 0.3,
-                'max_tokens': 1500
-            }
-        )
-
-        ai_message = AIMessage(
-            content=result.message.content or "",
-            additional_kwargs={
-                'function_call': {
-                    'id': result.message.tool_calls[0].id,
-                    **result.message.tool_calls[0].function.dict()
-                } if result.message.tool_calls else None
-            }
-        )
-
-        agent_decision = _parse_ai_message(ai_message)
-        return agent_decision
-
-    async def aplan(
-            self,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        raise NotImplementedError()
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            model_config: ModelConfigWithCredentialsEntity,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
-            system_message: Optional[SystemMessage] = SystemMessage(
-                content="You are a helpful AI assistant."
-            ),
-            **kwargs: Any,
-    ) -> BaseSingleActionAgent:
-        prompt = cls.create_prompt(
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=system_message,
-        )
-        return cls(
-            model_config=model_config,
-            llm=FakeLLM(response=''),
-            prompt=prompt,
-            tools=tools,
-            callback_manager=callback_manager,
-            **kwargs,
-        )
@@ -1,259 +0,0 @@
-import re
-from collections.abc import Sequence
-from typing import Any, Optional, Union, cast
-
-from langchain import BasePromptTemplate, PromptTemplate
-from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
-from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, OutputParserException
-from langchain.tools import BaseTool
-
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.rag.retrieval.agent.llm_chain import LLMChain
-
-FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
-The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
-Valid "action" values: "Final Answer" or {tool_names}
-
-Provide only ONE action per $JSON_BLOB, as shown:
-
-```
-{{{{
-  "action": $TOOL_NAME,
-  "action_input": $INPUT
-}}}}
-```
-
-Follow this format:
-
-Question: input question to answer
-Thought: consider previous and subsequent steps
-Action:
-```
-$JSON_BLOB
-```
-Observation: action result
-... (repeat Thought/Action/Observation N times)
-Thought: I know what to respond
-Action:
-```
-{{{{
-  "action": "Final Answer",
-  "action_input": "Final response to human"
-}}}}
-```"""
-
-
-class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
-    dataset_tools: Sequence[BaseTool]
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-        Using the ReACT mode to determine whether an agent is needed is costly,
-        so it's better to just use an Agent for reasoning, which is cheaper.
-
-        :param query:
-        :return:
-        """
-        return True
-
-    def plan(
-            self,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date,
-                along with observations
-            callbacks: Callbacks to run.
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        if len(self.dataset_tools) == 0:
-            return AgentFinish(return_values={"output": ''}, log='')
-        elif len(self.dataset_tools) == 1:
-            tool = next(iter(self.dataset_tools))
-            rst = tool.run(tool_input={'query': kwargs['input']})
-            return AgentFinish(return_values={"output": rst}, log=rst)
-
-        if intermediate_steps:
-            _, observation = intermediate_steps[-1]
-            return AgentFinish(return_values={"output": observation}, log=observation)
-
-        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
-
-        try:
-            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
-        except Exception as e:
-            raise e
-
-        try:
-            agent_decision = self.output_parser.parse(full_output)
-            if isinstance(agent_decision, AgentAction):
-                tool_inputs = agent_decision.tool_input
-                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
-                    tool_inputs['query'] = kwargs['input']
-                    agent_decision.tool_input = tool_inputs
-                elif isinstance(tool_inputs, str):
-                    agent_decision.tool_input = kwargs['input']
-            else:
-                agent_decision.return_values['output'] = ''
-            return agent_decision
-        except OutputParserException:
-            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
-                                          "I don't know how to respond to that."}, "")
-
-    @classmethod
-    def create_prompt(
-            cls,
-            tools: Sequence[BaseTool],
-            prefix: str = PREFIX,
-            suffix: str = SUFFIX,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
-            format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-            memory_prompts: Optional[list[BasePromptTemplate]] = None,
-    ) -> BasePromptTemplate:
-        tool_strings = []
-        for tool in tools:
-            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
-            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
-        formatted_tools = "\n".join(tool_strings)
-        unique_tool_names = set(tool.name for tool in tools)
-        tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
-        format_instructions = format_instructions.format(tool_names=tool_names)
-        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
-        if input_variables is None:
-            input_variables = ["input", "agent_scratchpad"]
-        _memory_prompts = memory_prompts or []
-        messages = [
-            SystemMessagePromptTemplate.from_template(template),
-            *_memory_prompts,
-            HumanMessagePromptTemplate.from_template(human_message_template),
-        ]
-        return ChatPromptTemplate(input_variables=input_variables, messages=messages)
-
-    @classmethod
-    def create_completion_prompt(
-            cls,
-            tools: Sequence[BaseTool],
-            prefix: str = PREFIX,
-            format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-    ) -> PromptTemplate:
-        """Create prompt in the style of the zero shot agent.
-
-        Args:
-            tools: List of tools the agent will have access to, used to format the
-                prompt.
-            prefix: String to put before the list of tools.
-            input_variables: List of input variables the final prompt will expect.
-
-        Returns:
-            A PromptTemplate with the template assembled from the pieces here.
-        """
-        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
-Question: {input}
-Thought: {agent_scratchpad}
-"""
-
-        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
-        tool_names = ", ".join([tool.name for tool in tools])
-        format_instructions = format_instructions.format(tool_names=tool_names)
-        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
-        if input_variables is None:
-            input_variables = ["input", "agent_scratchpad"]
-        return PromptTemplate(template=template, input_variables=input_variables)
-
-    def _construct_scratchpad(
-            self, intermediate_steps: list[tuple[AgentAction, str]]
-    ) -> str:
-        agent_scratchpad = ""
-        for action, observation in intermediate_steps:
-            agent_scratchpad += action.log
-            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
-
-        if not isinstance(agent_scratchpad, str):
-            raise ValueError("agent_scratchpad should be of type string.")
-        if agent_scratchpad:
-            llm_chain = cast(LLMChain, self.llm_chain)
-            if llm_chain.model_config.mode == "chat":
-                return (
-                    f"This was your previous work "
-                    f"(but I haven't seen any of it! I only see what "
-                    f"you return as final answer):\n{agent_scratchpad}"
-                )
-            else:
-                return agent_scratchpad
-        else:
-            return agent_scratchpad
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            model_config: ModelConfigWithCredentialsEntity,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            output_parser: Optional[AgentOutputParser] = None,
-            prefix: str = PREFIX,
-            suffix: str = SUFFIX,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
-            format_instructions: str = FORMAT_INSTRUCTIONS,
-            input_variables: Optional[list[str]] = None,
-            memory_prompts: Optional[list[BasePromptTemplate]] = None,
-            **kwargs: Any,
-    ) -> Agent:
-        """Construct an agent from an LLM and tools."""
-        cls._validate_tools(tools)
-        if model_config.mode == "chat":
-            prompt = cls.create_prompt(
-                tools,
-                prefix=prefix,
-                suffix=suffix,
-                human_message_template=human_message_template,
-                format_instructions=format_instructions,
-                input_variables=input_variables,
-                memory_prompts=memory_prompts,
-            )
-        else:
-            prompt = cls.create_completion_prompt(
-                tools,
-                prefix=prefix,
-                format_instructions=format_instructions,
-                input_variables=input_variables
-            )
-
-        llm_chain = LLMChain(
-            model_config=model_config,
-            prompt=prompt,
-            callback_manager=callback_manager,
-            parameters={
-                'temperature': 0.2,
-                'top_p': 0.3,
-                'max_tokens': 1500
-            }
-        )
-        tool_names = [tool.name for tool in tools]
-        _output_parser = output_parser
-        return cls(
-            llm_chain=llm_chain,
-            allowed_tools=tool_names,
-            output_parser=_output_parser,
-            dataset_tools=tools,
-            **kwargs,
-        )
@@ -1,117 +0,0 @@
-import logging
-from typing import Optional, Union
-
-from langchain.agents import AgentExecutor as LCAgentExecutor
-from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent
-from langchain.callbacks.manager import Callbacks
-from langchain.tools import BaseTool
-from pydantic import BaseModel, Extra
-
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.agent_entities import PlanningStrategy
-from core.entities.message_entities import prompt_messages_to_lc_messages
-from core.helper import moderation
-from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_runtime.errors.invoke import InvokeError
-from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
-from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
-from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
-from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
-from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
-
-
-class AgentConfiguration(BaseModel):
-    strategy: PlanningStrategy
-    model_config: ModelConfigWithCredentialsEntity
-    tools: list[BaseTool]
-    summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None
-    memory: Optional[TokenBufferMemory] = None
-    callbacks: Callbacks = None
-    max_iterations: int = 6
-    max_execution_time: Optional[float] = None
-    early_stopping_method: str = "generate"
-    # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-        arbitrary_types_allowed = True
-
-
-class AgentExecuteResult(BaseModel):
-    strategy: PlanningStrategy
-    output: Optional[str]
-    configuration: AgentConfiguration
-
-
-class AgentExecutor:
-    def __init__(self, configuration: AgentConfiguration):
-        self.configuration = configuration
-        self.agent = self._init_agent()
-
-    def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
-        if self.configuration.strategy == PlanningStrategy.ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools
-                                        if isinstance(t, DatasetRetrieverTool)
-                                        or isinstance(t, DatasetMultiRetrieverTool)]
-            agent = MultiDatasetRouterAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
-                if self.configuration.memory else None,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools
-                                        if isinstance(t, DatasetRetrieverTool)
-                                        or isinstance(t, DatasetMultiRetrieverTool)]
-            agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                output_parser=StructuredChatOutputParser(),
-                verbose=True
-            )
-        else:
-            raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
-
-        return agent
-
-    def should_use_agent(self, query: str) -> bool:
-        return self.agent.should_use_agent(query)
-
-    def run(self, query: str) -> AgentExecuteResult:
-        moderation_result = moderation.check_moderation(
-            self.configuration.model_config,
-            query
-        )
-
-        if moderation_result:
-            return AgentExecuteResult(
-                output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
-                strategy=self.configuration.strategy,
-                configuration=self.configuration
-            )
-
-        agent_executor = LCAgentExecutor.from_agent_and_tools(
-            agent=self.agent,
-            tools=self.configuration.tools,
-            max_iterations=self.configuration.max_iterations,
-            max_execution_time=self.configuration.max_execution_time,
-            early_stopping_method=self.configuration.early_stopping_method,
-            callbacks=self.configuration.callbacks
-        )
-
-        try:
-            output = agent_executor.run(input=query)
-        except InvokeError as ex:
-            raise ex
-        except Exception as ex:
-            logging.exception("agent_executor run failed")
-            output = None
-
-        return AgentExecuteResult(
-            output=output,
-            strategy=self.configuration.strategy,
-            configuration=self.configuration
-        )
@@ -1,5 +1,7 @@
+import threading
 from typing import Optional, cast

+from flask import Flask, current_app
 from langchain.tools import BaseTool

 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
@@ -7,17 +9,35 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.agent_entities import PlanningStrategy
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_runtime.entities.model_entities import ModelFeature
+from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.message_entities import PromptMessageTool
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.models.document import Document
+from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
+from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
+from core.rerank.rerank import RerankRunner
 from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 from extensions.ext_database import db
-from models.dataset import Dataset
+from models.dataset import Dataset, DatasetQuery, DocumentSegment
+from models.dataset import Document as DatasetDocument
+
+default_retrieval_model = {
+    'search_method': 'semantic_search',
+    'reranking_enable': False,
+    'reranking_model': {
+        'reranking_provider_name': '',
+        'reranking_model_name': ''
+    },
+    'top_k': 2,
+    'score_threshold_enabled': False
+}


 class DatasetRetrieval:
-    def retrieve(self, tenant_id: str,
+    def retrieve(self, app_id: str, user_id: str, tenant_id: str,
                  model_config: ModelConfigWithCredentialsEntity,
                  config: DatasetEntity,
                  query: str,
@@ -27,6 +47,8 @@ class DatasetRetrieval:
                  memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
         """
         Retrieve dataset.
+        :param app_id: app_id
+        :param user_id: user_id
         :param tenant_id: tenant id
         :param model_config: model config
         :param config: dataset config
@@ -38,12 +60,22 @@ class DatasetRetrieval:
         :return:
         """
         dataset_ids = config.dataset_ids
         if len(dataset_ids) == 0:
             return None
         retrieve_config = config.retrieve_config

         # check model is support tool calling
         model_type_instance = model_config.provider_model_bundle.model_type_instance
         model_type_instance = cast(LargeLanguageModel, model_type_instance)

+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            model_type=ModelType.LLM,
+            provider=model_config.provider,
+            model=model_config.model
+        )
+
         # get model schema
         model_schema = model_type_instance.get_model_schema(
             model=model_config.model,
@@ -59,38 +91,291 @@ class DatasetRetrieval:
         if ModelFeature.TOOL_CALL in features \
                 or ModelFeature.MULTI_TOOL_CALL in features:
             planning_strategy = PlanningStrategy.ROUTER
-        dataset_retriever_tools = self.to_dataset_retriever_tool(
-            tenant_id=tenant_id,
-            dataset_ids=dataset_ids,
-            retrieve_config=retrieve_config,
-            return_resource=show_retrieve_source,
-            invoke_from=invoke_from,
-            hit_callback=hit_callback
-        )
-
-        if len(dataset_retriever_tools) == 0:
-            return None
-
-        agent_configuration = AgentConfiguration(
-            strategy=planning_strategy,
-            model_config=model_config,
-            tools=dataset_retriever_tools,
-            memory=memory,
-            max_iterations=10,
-            max_execution_time=400.0,
-            early_stopping_method="generate"
-        )
-
-        agent_executor = AgentExecutor(agent_configuration)
-
-        should_use_agent = agent_executor.should_use_agent(query)
-        if not should_use_agent:
-            return None
-
-        result = agent_executor.run(query)
-
-        return result.output
+        available_datasets = []
+        for dataset_id in dataset_ids:
+            # get dataset from dataset id
+            dataset = db.session.query(Dataset).filter(
+                Dataset.tenant_id == tenant_id,
+                Dataset.id == dataset_id
+            ).first()
+
+            # pass if dataset is not available
+            if not dataset:
+                continue
+
+            # pass if dataset is not available
+            if (dataset and dataset.available_document_count == 0
+                    and dataset.available_document_count == 0):
+                continue
+
+            available_datasets.append(dataset)
+        all_documents = []
+        user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
+        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
+            all_documents = self.single_retrieve(app_id, tenant_id, user_id, user_from, available_datasets, query,
+                                                 model_instance,
+                                                 model_config, planning_strategy)
+        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
+            all_documents = self.multiple_retrieve(app_id, tenant_id, user_id, user_from,
+                                                   available_datasets, query, retrieve_config.top_k,
+                                                   retrieve_config.score_threshold,
+                                                   retrieve_config.reranking_model.get('reranking_provider_name'),
+                                                   retrieve_config.reranking_model.get('reranking_model_name'))
+
+        document_score_list = {}
+        for item in all_documents:
+            if 'score' in item.metadata and item.metadata['score']:
+                document_score_list[item.metadata['doc_id']] = item.metadata['score']
+
+        document_context_list = []
+        index_node_ids = [document.metadata['doc_id'] for document in all_documents]
+        segments = DocumentSegment.query.filter(
+            DocumentSegment.dataset_id.in_(dataset_ids),
+            DocumentSegment.completed_at.isnot(None),
+            DocumentSegment.status == 'completed',
+            DocumentSegment.enabled == True,
+            DocumentSegment.index_node_id.in_(index_node_ids)
+        ).all()
+
+        if segments:
+            index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
+            sorted_segments = sorted(segments,
+                                     key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
+                                                                                       float('inf')))
+            for segment in sorted_segments:
+                if segment.answer:
+                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+                else:
+                    document_context_list.append(segment.content)
+            if show_retrieve_source:
+                context_list = []
+                resource_number = 1
+                for segment in sorted_segments:
+                    dataset = Dataset.query.filter_by(
+                        id=segment.dataset_id
+                    ).first()
+                    document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id,
+                                                            DatasetDocument.enabled == True,
+                                                            DatasetDocument.archived == False,
+                                                            ).first()
+                    if dataset and document:
+                        source = {
+                            'position': resource_number,
+                            'dataset_id': dataset.id,
+                            'dataset_name': dataset.name,
+                            'document_id': document.id,
+                            'document_name': document.name,
+                            'data_source_type': document.data_source_type,
+                            'segment_id': segment.id,
+                            'retriever_from': invoke_from.to_source(),
+                            'score': document_score_list.get(segment.index_node_id, None)
+                        }
+
+                        if invoke_from.to_source() == 'dev':
+                            source['hit_count'] = segment.hit_count
+                            source['word_count'] = segment.word_count
+                            source['segment_position'] = segment.position
+                            source['index_node_hash'] = segment.index_node_hash
+                        if segment.answer:
+                            source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+                        else:
+                            source['content'] = segment.content
+                        context_list.append(source)
+                        resource_number += 1
+                if hit_callback:
+                    hit_callback.return_retriever_resource_info(context_list)
+
+            return str("\n".join(document_context_list))
+        return ''
+
+    def single_retrieve(self, app_id: str,
+                        tenant_id: str,
+                        user_id: str,
+                        user_from: str,
+                        available_datasets: list,
+                        query: str,
+                        model_instance: ModelInstance,
+                        model_config: ModelConfigWithCredentialsEntity,
+                        planning_strategy: PlanningStrategy,
+                        ):
+        tools = []
+        for dataset in available_datasets:
+            description = dataset.description
+            if not description:
+                description = 'useful for when you want to answer queries about the ' + dataset.name
+
+            description = description.replace('\n', '').replace('\r', '')
+            message_tool = PromptMessageTool(
+                name=dataset.id,
+                description=description,
+                parameters={
+                    "type": "object",
+                    "properties": {},
+                    "required": [],
+                }
+            )
+            tools.append(message_tool)
+        dataset_id = None
+        if planning_strategy == PlanningStrategy.REACT_ROUTER:
+            react_multi_dataset_router = ReactMultiDatasetRouter()
+            dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance,
+                                                           user_id, tenant_id)
+
+        elif planning_strategy == PlanningStrategy.ROUTER:
+            function_call_router = FunctionCallMultiDatasetRouter()
+            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
+
+        if dataset_id:
+            # get retrieval model config
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
+            if dataset:
+                retrieval_model_config = dataset.retrieval_model \
+                    if dataset.retrieval_model else default_retrieval_model
+
+                # get top k
+                top_k = retrieval_model_config['top_k']
+                # get retrieval method
+                if dataset.indexing_technique == "economy":
+                    retrival_method = 'keyword_search'
+                else:
+                    retrival_method = retrieval_model_config['search_method']
+                # get reranking model
+                reranking_model = retrieval_model_config['reranking_model'] \
+                    if retrieval_model_config['reranking_enable'] else None
+                # get score threshold
+                score_threshold = .0
+                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
+                if score_threshold_enabled:
+                    score_threshold = retrieval_model_config.get("score_threshold")
+
+                results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
+                                                    query=query,
+                                                    top_k=top_k, score_threshold=score_threshold,
+                                                    reranking_model=reranking_model)
+                self._on_query(query, [dataset_id], app_id, user_from, user_id)
+                if results:
+                    self._on_retrival_end(results)
+                return results
+        return []
+
+    def multiple_retrieve(self,
+                          app_id: str,
+                          tenant_id: str,
+                          user_id: str,
+                          user_from: str,
+                          available_datasets: list,
+                          query: str,
+                          top_k: int,
+                          score_threshold: float,
+                          reranking_provider_name: str,
+                          reranking_model_name: str):
+        threads = []
+        all_documents = []
+        dataset_ids = [dataset.id for dataset in available_datasets]
+        for dataset in available_datasets:
+            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset_id': dataset.id,
+                'query': query,
+                'top_k': top_k,
+                'all_documents': all_documents,
+            })
+            threads.append(retrieval_thread)
+            retrieval_thread.start()
+        for thread in threads:
+            thread.join()
+        # do rerank for searched documents
+        model_manager = ModelManager()
+        rerank_model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            provider=reranking_provider_name,
+            model_type=ModelType.RERANK,
+            model=reranking_model_name
+        )
+
+        rerank_runner = RerankRunner(rerank_model_instance)
+        all_documents = rerank_runner.run(query, all_documents,
+                                          score_threshold,
+                                          top_k)
+        self._on_query(query, dataset_ids, app_id, user_from, user_id)
+        if all_documents:
+            self._on_retrival_end(all_documents)
+        return all_documents
+
+    def _on_retrival_end(self, documents: list[Document]) -> None:
+        """Handle retrival end."""
+        for document in documents:
+            query = db.session.query(DocumentSegment).filter(
+                DocumentSegment.index_node_id == document.metadata['doc_id']
+            )
+
+            if 'dataset_id' in document.metadata:
+                query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
+
+            # add hit count to document segment
+            query.update(
+                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
+                synchronize_session=False
+            )
+
+            db.session.commit()
+
+    def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
+        """
+        Handle query.
+        """
+        if not query:
+            return
+        for dataset_id in dataset_ids:
+            dataset_query = DatasetQuery(
+                dataset_id=dataset_id,
+                content=query,
+                source='app',
+                source_app_id=app_id,
+                created_by_role=user_from,
+                created_by=user_id
+            )
+            db.session.add(dataset_query)
+        db.session.commit()
+
+    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
+        with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
+
+            if not dataset:
+                return []
+
+            # get retrieval model , if the model is not setting , using default
+            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+
+            if dataset.indexing_technique == "economy":
+                # use keyword table query
+                documents = RetrievalService.retrieve(retrival_method='keyword_search',
+                                                      dataset_id=dataset.id,
+                                                      query=query,
+                                                      top_k=top_k
+                                                      )
+                if documents:
+                    all_documents.extend(documents)
+            else:
+                if top_k > 0:
+                    # retrieval source
+                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                          dataset_id=dataset.id,
+                                                          query=query,
+                                                          top_k=top_k,
+                                                          score_threshold=retrieval_model['score_threshold']
+                                                          if retrieval_model['score_threshold_enabled'] else None,
+                                                          reranking_model=retrieval_model['reranking_model']
+                                                          if retrieval_model['reranking_enable'] else None
+                                                          )
+
+                    all_documents.extend(documents)

     def to_dataset_retriever_tool(self, tenant_id: str,
                                   dataset_ids: list[str],

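The new `multiple_retrieve` above fans out one retrieval thread per dataset, each appending into a shared list, then joins them all and reranks the merged hits. A minimal sketch of that fan-out/join shape with retrieval stubbed out; `list.extend` is atomic under CPython's GIL, which is what the real code relies on.

```python
import threading

def _retriever(dataset_id: str, query: str, all_documents: list) -> None:
    # stand-in for RetrievalService.retrieve(...); each worker extends the
    # shared list with its own dataset's hits
    all_documents.extend([f'{dataset_id}: hit for {query!r}'])

all_documents: list[str] = []
threads = []
for dataset_id in ['ds-1', 'ds-2', 'ds-3']:
    t = threading.Thread(target=_retriever, args=(dataset_id, 'tree of thought', all_documents))
    threads.append(t)
    t.start()
for t in threads:
    t.join()
print(all_documents)  # merged hits, then rerank/score-filter as above
```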
@@ -12,8 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
-from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
-from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
+from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
 from core.workflow.nodes.llm.llm_node import LLMNode

 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -55,11 +54,10 @@ class ReactMultiDatasetRouter:
             self,
             query: str,
             dataset_tools: list[PromptMessageTool],
-            node_data: KnowledgeRetrievalNodeData,
             model_config: ModelConfigWithCredentialsEntity,
             model_instance: ModelInstance,
             user_id: str,
-            tenant_id: str,
+            tenant_id: str

     ) -> Union[str, None]:
         """Given input, decided what to do.
@@ -72,7 +70,8 @@ class ReactMultiDatasetRouter:
             return dataset_tools[0].name

         try:
-            return self._react_invoke(query=query, node_data=node_data, model_config=model_config, model_instance=model_instance,
+            return self._react_invoke(query=query, model_config=model_config,
+                                      model_instance=model_instance,
                                       tools=dataset_tools, user_id=user_id, tenant_id=tenant_id)
         except Exception as e:
             return None
@@ -80,7 +79,6 @@ class ReactMultiDatasetRouter:
     def _react_invoke(
             self,
             query: str,
-            node_data: KnowledgeRetrievalNodeData,
             model_config: ModelConfigWithCredentialsEntity,
             model_instance: ModelInstance,
             tools: Sequence[PromptMessageTool],
@@ -121,7 +119,7 @@ class ReactMultiDatasetRouter:
             model_config=model_config
         )
         result_text, usage = self._invoke_llm(
-            node_data=node_data,
+            completion_param=model_config.parameters,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop,
@@ -134,10 +132,11 @@ class ReactMultiDatasetRouter:
             return agent_decision.tool
         return None

-    def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData,
+    def _invoke_llm(self, completion_param: dict,
                     model_instance: ModelInstance,
                     prompt_messages: list[PromptMessage],
-                    stop: list[str], user_id: str, tenant_id: str) -> tuple[str, LLMUsage]:
+                    stop: list[str], user_id: str, tenant_id: str
+                    ) -> tuple[str, LLMUsage]:
         """
         Invoke large language model
         :param node_data: node data
@@ -148,7 +147,7 @@ class ReactMultiDatasetRouter:
         """
         invoke_result = model_instance.invoke_llm(
             prompt_messages=prompt_messages,
-            model_parameters=node_data.single_retrieval_config.model.completion_params,
+            model_parameters=completion_param,
             stop=stop,
             stream=True,
             user=user_id,
@@ -203,7 +202,8 @@ class ReactMultiDatasetRouter:
     ) -> list[ChatModelMessage]:
         tool_strings = []
         for tool in tools:
-            tool_strings.append(f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
+            tool_strings.append(
+                f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
         formatted_tools = "\n".join(tool_strings)
         unique_tool_names = set(tool.name for tool in tools)
         tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
@@ -38,8 +38,10 @@ Action:
 ```

 Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
+{{historic_messages}}
 Question: {{query}}
-Thought: {{agent_scratchpad}}"""
+{{agent_scratchpad}}
+Thought:"""

 ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
 Thought:"""

@@ -1,11 +1,92 @@
-from typing import Any
+import logging
+from typing import Any, Optional

-from langchain.utilities import ArxivAPIWrapper
+import arxiv
+from pydantic import BaseModel, Field

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool

+logger = logging.getLogger(__name__)
+
+
+class ArxivAPIWrapper(BaseModel):
+    """Wrapper around ArxivAPI.
+
+    To use, you should have the ``arxiv`` python package installed.
+    https://lukasschwab.me/arxiv.py/index.html
+    This wrapper will use the Arxiv API to conduct searches and
+    fetch document summaries. By default, it will return the document summaries
+    of the top-k results.
+    It limits the Document content by doc_content_chars_max.
+    Set doc_content_chars_max=None if you don't want to limit the content size.
+
+    Args:
+        top_k_results: number of the top-scored documents used for the arxiv tool
+        ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
+        load_max_docs: a limit to the number of loaded documents
+        load_all_available_meta:
+            if True: the `metadata` of the loaded Documents contains all available
+            meta info (see https://lukasschwab.me/arxiv.py/index.html#Result),
+            if False: the `metadata` contains only the published date, title,
+            authors and summary.
+        doc_content_chars_max: an optional cut limit for the length of a document's
+            content
+
+    Example:
+        .. code-block:: python
+
+            arxiv = ArxivAPIWrapper(
+                top_k_results = 3,
+                ARXIV_MAX_QUERY_LENGTH = 300,
+                load_max_docs = 3,
+                load_all_available_meta = False,
+                doc_content_chars_max = 40000
+            )
+            arxiv.run("tree of thought llm")
+    """
+
+    arxiv_search = arxiv.Search  #: :meta private:
+    arxiv_exceptions = (
+        arxiv.ArxivError,
+        arxiv.UnexpectedEmptyPageError,
+        arxiv.HTTPError,
+    )  # :meta private:
+    top_k_results: int = 3
+    ARXIV_MAX_QUERY_LENGTH = 300
+    load_max_docs: int = 100
+    load_all_available_meta: bool = False
+    doc_content_chars_max: Optional[int] = 4000
+
+    def run(self, query: str) -> str:
+        """
+        Performs an arxiv search and returns a single string
+        with the publish date, title, authors, and summary
+        for each article separated by two newlines.
+
+        If an error occurs or no documents are found, error text
+        is returned instead. Wrapper for
+        https://lukasschwab.me/arxiv.py/index.html#Search
+
+        Args:
+            query: a plaintext search query
+        """  # noqa: E501
+        try:
+            results = self.arxiv_search(  # type: ignore
+                query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
+            ).results()
+        except self.arxiv_exceptions as ex:
+            return f"Arxiv exception: {ex}"
+        docs = [
+            f"Published: {result.updated.date()}\n"
+            f"Title: {result.title}\n"
+            f"Authors: {', '.join(a.name for a in result.authors)}\n"
+            f"Summary: {result.summary}"
+            for result in results
+        ]
+        if docs:
+            return "\n\n".join(docs)[: self.doc_content_chars_max]
+        else:
+            return "No good Arxiv Result was found"
+
+
 class ArxivSearchInput(BaseModel):
     query: str = Field(..., description="Search query.")

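Usage sketch for the inlined `ArxivAPIWrapper` above, mirroring its docstring example (requires the `arxiv` package and performs a live API call):

```python
arxiv_wrapper = ArxivAPIWrapper(
    top_k_results=3,
    load_max_docs=3,
    load_all_available_meta=False,
    doc_content_chars_max=40000,
)
print(arxiv_wrapper.run("tree of thought llm"))
```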
@@ -12,6 +12,7 @@ class BingSearchTool(BuiltinTool):
|
||||
|
||||
def _invoke_bing(self,
|
||||
user_id: str,
|
||||
server_url: str,
|
||||
subscription_key: str, query: str, limit: int,
|
||||
result_type: str, market: str, lang: str,
|
||||
filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
@@ -26,7 +27,7 @@ class BingSearchTool(BuiltinTool):
|
||||
}
|
||||
|
||||
query = quote(query)
|
||||
server_url = f'{self.url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
|
||||
server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
|
||||
response = get(server_url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
@@ -136,6 +137,7 @@ class BingSearchTool(BuiltinTool):
|
||||
|
||||
self._invoke_bing(
|
||||
user_id='test',
|
||||
server_url=server_url,
|
||||
subscription_key=key,
|
||||
query=query,
|
||||
limit=limit,
|
||||
@@ -188,6 +190,7 @@ class BingSearchTool(BuiltinTool):
|
||||
|
||||
return self._invoke_bing(
|
||||
user_id=user_id,
|
||||
server_url=server_url,
|
||||
subscription_key=key,
|
||||
query=query,
|
||||
limit=limit,
|
||||
|
@@ -1,11 +1,95 @@
-from typing import Any
+import json
+from typing import Any, Optional

-from langchain.tools import BraveSearch
+import requests
+from pydantic import BaseModel, Field

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool


+class BraveSearchWrapper(BaseModel):
+    """Wrapper around the Brave search engine."""
+
+    api_key: str
+    """The API key to use for the Brave search engine."""
+    search_kwargs: dict = Field(default_factory=dict)
+    """Additional keyword arguments to pass to the search request."""
+    base_url = "https://api.search.brave.com/res/v1/web/search"
+    """The base URL for the Brave search engine."""
+
+    def run(self, query: str) -> str:
+        """Query the Brave search engine and return the results as a JSON string.
+
+        Args:
+            query: The query to search for.
+
+        Returns: The results as a JSON string.
+
+        """
+        web_search_results = self._search_request(query=query)
+        final_results = [
+            {
+                "title": item.get("title"),
+                "link": item.get("url"),
+                "snippet": item.get("description"),
+            }
+            for item in web_search_results
+        ]
+        return json.dumps(final_results)
+
+    def _search_request(self, query: str) -> list[dict]:
+        headers = {
+            "X-Subscription-Token": self.api_key,
+            "Accept": "application/json",
+        }
+        req = requests.PreparedRequest()
+        params = {**self.search_kwargs, **{"q": query}}
+        req.prepare_url(self.base_url, params)
+        if req.url is None:
+            raise ValueError("prepared url is None, this should not happen")
+
+        response = requests.get(req.url, headers=headers)
+        if not response.ok:
+            raise Exception(f"HTTP error {response.status_code}")
+
+        return response.json().get("web", {}).get("results", [])
+
+
+class BraveSearch(BaseModel):
+    """Tool that queries the BraveSearch."""
+
+    name = "brave_search"
+    description = (
+        "a search engine. "
+        "useful for when you need to answer questions about current events."
+        " input should be a search query."
+    )
+    search_wrapper: BraveSearchWrapper
+
+    @classmethod
+    def from_api_key(
+        cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any
+    ) -> "BraveSearch":
+        """Create a tool from an api key.
+
+        Args:
+            api_key: The api key to use.
+            search_kwargs: Any additional kwargs to pass to the search wrapper.
+            **kwargs: Any additional kwargs to pass to the tool.
+
+        Returns:
+            A tool.
+        """
+        wrapper = BraveSearchWrapper(api_key=api_key, search_kwargs=search_kwargs or {})
+        return cls(search_wrapper=wrapper, **kwargs)
+
+    def _run(
+        self,
+        query: str,
+    ) -> str:
+        """Use the tool."""
+        return self.search_wrapper.run(query)
+
+
 class BraveSearchTool(BuiltinTool):
     """
     Tool for performing a search using Brave search engine.
@@ -31,7 +115,7 @@ class BraveSearchTool(BuiltinTool):

         tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count})

-        results = tool.run(query)
+        results = tool._run(query)

         if not results:
             return self.create_text_message(f"No results found for '{query}' in Tavily")

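For orientation, a sketch of how the two inlined classes above chain together (the API key is a placeholder; `search_kwargs` entries are forwarded verbatim as Brave query parameters):

# Sketch only: BraveSearchWrapper prepares the request URL via
# requests.PreparedRequest; BraveSearch is the thin facade dify's tool calls.
tool = BraveSearch.from_api_key(api_key="YOUR-BRAVE-KEY", search_kwargs={"count": 3})
print(tool._run("dify workflow"))  # JSON list of {"title", "link", "snippet"}
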
@@ -1,16 +1,147 @@
-from typing import Any
+from typing import Any, Optional

-from langchain.tools import DuckDuckGoSearchRun
 from pydantic import BaseModel, Field

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool


+class DuckDuckGoSearchAPIWrapper(BaseModel):
+    """Wrapper for DuckDuckGo Search API.
+
+    Free and does not require any setup.
+    """
+
+    region: Optional[str] = "wt-wt"
+    safesearch: str = "moderate"
+    time: Optional[str] = "y"
+    max_results: int = 5
+
+    def get_snippets(self, query: str) -> list[str]:
+        """Run query through DuckDuckGo and return concatenated results."""
+        from duckduckgo_search import DDGS
+
+        with DDGS() as ddgs:
+            results = ddgs.text(
+                query,
+                region=self.region,
+                safesearch=self.safesearch,
+                timelimit=self.time,
+            )
+            if results is None:
+                return ["No good DuckDuckGo Search Result was found"]
+            snippets = []
+            for i, res in enumerate(results, 1):
+                if res is not None:
+                    snippets.append(res["body"])
+                if len(snippets) == self.max_results:
+                    break
+        return snippets
+
+    def run(self, query: str) -> str:
+        snippets = self.get_snippets(query)
+        return " ".join(snippets)
+
+    def results(
+        self, query: str, num_results: int, backend: str = "api"
+    ) -> list[dict[str, str]]:
+        """Run query through DuckDuckGo and return metadata.
+
+        Args:
+            query: The query to search for.
+            num_results: The number of results to return.
+
+        Returns:
+            A list of dictionaries with the following keys:
+                snippet - The description of the result.
+                title - The title of the result.
+                link - The link to the result.
+        """
+        from duckduckgo_search import DDGS
+
+        with DDGS() as ddgs:
+            results = ddgs.text(
+                query,
+                region=self.region,
+                safesearch=self.safesearch,
+                timelimit=self.time,
+                backend=backend,
+            )
+            if results is None:
+                return [{"Result": "No good DuckDuckGo Search Result was found"}]
+
+            def to_metadata(result: dict) -> dict[str, str]:
+                if backend == "news":
+                    return {
+                        "date": result["date"],
+                        "title": result["title"],
+                        "snippet": result["body"],
+                        "source": result["source"],
+                        "link": result["url"],
+                    }
+                return {
+                    "snippet": result["body"],
+                    "title": result["title"],
+                    "link": result["href"],
+                }
+
+            formatted_results = []
+            for i, res in enumerate(results, 1):
+                if res is not None:
+                    formatted_results.append(to_metadata(res))
+                if len(formatted_results) == num_results:
+                    break
+        return formatted_results
+
+
+class DuckDuckGoSearchRun(BaseModel):
+    """Tool that queries the DuckDuckGo search API."""
+
+    name = "duckduckgo_search"
+    description = (
+        "A wrapper around DuckDuckGo Search. "
+        "Useful for when you need to answer questions about current events. "
+        "Input should be a search query."
+    )
+    api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
+        default_factory=DuckDuckGoSearchAPIWrapper
+    )
+
+    def _run(
+        self,
+        query: str,
+    ) -> str:
+        """Use the tool."""
+        return self.api_wrapper.run(query)
+
+
+class DuckDuckGoSearchResults(BaseModel):
+    """Tool that queries the DuckDuckGo search API and gets back json."""
+
+    name = "DuckDuckGo Results JSON"
+    description = (
+        "A wrapper around Duck Duck Go Search. "
+        "Useful for when you need to answer questions about current events. "
+        "Input should be a search query. Output is a JSON array of the query results"
+    )
+    num_results: int = 4
+    api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
+        default_factory=DuckDuckGoSearchAPIWrapper
+    )
+    backend: str = "api"
+
+    def _run(
+        self,
+        query: str,
+    ) -> str:
+        """Use the tool."""
+        res = self.api_wrapper.results(query, self.num_results, backend=self.backend)
+        res_strs = [", ".join([f"{k}: {v}" for k, v in d.items()]) for d in res]
+        return ", ".join([f"[{rs}]" for rs in res_strs])
+
+
 class DuckDuckGoInput(BaseModel):
     query: str = Field(..., description="Search query.")


 class DuckDuckGoSearchTool(BuiltinTool):
     """
     Tool for performing a search using DuckDuckGo search engine.
@@ -34,7 +165,7 @@ class DuckDuckGoSearchTool(BuiltinTool):

         tool = DuckDuckGoSearchRun(args_schema=DuckDuckGoInput)

-        result = tool.run(query)
+        result = tool._run(query)

         return self.create_text_message(self.summary(user_id=user_id, content=result))

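A short sketch of the two entry points on the wrapper above (assumes the `duckduckgo_search` package is installed; queries are arbitrary examples):

wrapper = DuckDuckGoSearchAPIWrapper(region="wt-wt", max_results=3)
print(wrapper.run("large language models"))         # snippets joined by spaces
print(wrapper.results("large language models", 3))  # [{"snippet", "title", "link"}, ...]
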
@@ -70,43 +70,44 @@ class SerpAPI:
             raise ValueError(f"Got error from SerpAPI: {res['error']}")

         if typ == "text":
+            toret = ""
             if "answer_box" in res.keys() and type(res["answer_box"]) == list:
-                res["answer_box"] = res["answer_box"][0]
+                res["answer_box"] = res["answer_box"][0] + "\n"
             if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
-                toret = res["answer_box"]["answer"]
-            elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
-                toret = res["answer_box"]["snippet"]
-            elif (
+                toret += res["answer_box"]["answer"] + "\n"
+            if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
+                toret += res["answer_box"]["snippet"] + "\n"
+            if (
                 "answer_box" in res.keys()
                 and "snippet_highlighted_words" in res["answer_box"].keys()
             ):
-                toret = res["answer_box"]["snippet_highlighted_words"][0]
-            elif (
+                for item in res["answer_box"]["snippet_highlighted_words"]:
+                    toret += item + "\n"
+            if (
                 "sports_results" in res.keys()
                 and "game_spotlight" in res["sports_results"].keys()
             ):
-                toret = res["sports_results"]["game_spotlight"]
-            elif (
+                toret += res["sports_results"]["game_spotlight"] + "\n"
+            if (
                 "shopping_results" in res.keys()
                 and "title" in res["shopping_results"][0].keys()
             ):
-                toret = res["shopping_results"][:3]
-            elif (
+                toret += res["shopping_results"][:3] + "\n"
+            if (
                 "knowledge_graph" in res.keys()
                 and "description" in res["knowledge_graph"].keys()
             ):
-                toret = res["knowledge_graph"]["description"]
-            elif "snippet" in res["organic_results"][0].keys():
-                toret = res["organic_results"][0]["snippet"]
-            elif "link" in res["organic_results"][0].keys():
-                toret = res["organic_results"][0]["link"]
-            elif (
+                toret = res["knowledge_graph"]["description"] + "\n"
+            if "snippet" in res["organic_results"][0].keys():
+                for item in res["organic_results"]:
+                    toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n"
+            if (
                 "images_results" in res.keys()
                 and "thumbnail" in res["images_results"][0].keys()
             ):
                 thumbnails = [item["thumbnail"] for item in res["images_results"][:10]]
                 toret = thumbnails
-            else:
+            if toret == "":
                 toret = "No good search result found"
         elif typ == "link":
             if "knowledge_graph" in res.keys() and "title" in res["knowledge_graph"].keys() \

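The net effect of the SerpAPI hunk above: the old `elif` chain kept only the first matching section of the response, while the new `if` chain appends every matching section to `toret`. (One retained quirk: `res["shopping_results"][:3] + "\n"` adds a list to a string and would raise a TypeError if that branch fires.) A toy illustration of the control-flow difference, using a fabricated response dict:

# Fabricated stand-in for a SerpAPI response; illustration only.
res = {"answer_box": {"answer": "42"}, "knowledge_graph": {"description": "The answer."}}

old = ""
if "answer" in res.get("answer_box", {}):
    old = res["answer_box"]["answer"]              # elif chain: stops at first hit
elif "description" in res.get("knowledge_graph", {}):
    old = res["knowledge_graph"]["description"]

new = ""
if "answer" in res.get("answer_box", {}):
    new += res["answer_box"]["answer"] + "\n"      # independent ifs: both fire
if "description" in res.get("knowledge_graph", {}):
    new += res["knowledge_graph"]["description"] + "\n"

assert old == "42" and new == "42\nThe answer.\n"
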
@@ -1,16 +1,187 @@
+import json
+import time
+import urllib.error
+import urllib.parse
+import urllib.request
 from typing import Any

-from langchain.tools import PubmedQueryRun
 from pydantic import BaseModel, Field

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool


+class PubMedAPIWrapper(BaseModel):
+    """
+    Wrapper around PubMed API.
+
+    This wrapper will use the PubMed API to conduct searches and fetch
+    document summaries. By default, it will return the document summaries
+    of the top-k results of an input search.
+
+    Parameters:
+        top_k_results: number of the top-scored documents used for the PubMed tool
+        load_max_docs: a limit to the number of loaded documents
+        load_all_available_meta:
+            if True: the `metadata` of the loaded Documents gets all available meta info
+                (see https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch)
+            if False: the `metadata` gets only the most informative fields.
+    """
+
+    base_url_esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
+    base_url_efetch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
+    max_retry = 5
+    sleep_time = 0.2
+
+    # Default values for the parameters
+    top_k_results: int = 3
+    load_max_docs: int = 25
+    ARXIV_MAX_QUERY_LENGTH = 300
+    doc_content_chars_max: int = 2000
+    load_all_available_meta: bool = False
+    email: str = "your_email@example.com"
+
+    def run(self, query: str) -> str:
+        """
+        Run PubMed search and get the article meta information.
+        See https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch
+        It uses only the most informative fields of article meta information.
+        """
+
+        try:
+            # Retrieve the top-k results for the query
+            docs = [
+                f"Published: {result['pub_date']}\nTitle: {result['title']}\n"
+                f"Summary: {result['summary']}"
+                for result in self.load(query[: self.ARXIV_MAX_QUERY_LENGTH])
+            ]
+
+            # Join the results and limit the character count
+            return (
+                "\n\n".join(docs)[:self.doc_content_chars_max]
+                if docs
+                else "No good PubMed Result was found"
+            )
+        except Exception as ex:
+            return f"PubMed exception: {ex}"
+
+    def load(self, query: str) -> list[dict]:
+        """
+        Search PubMed for documents matching the query.
+        Return a list of dictionaries containing the document metadata.
+        """
+
+        url = (
+            self.base_url_esearch
+            + "db=pubmed&term="
+            + str({urllib.parse.quote(query)})
+            + f"&retmode=json&retmax={self.top_k_results}&usehistory=y"
+        )
+        result = urllib.request.urlopen(url)
+        text = result.read().decode("utf-8")
+        json_text = json.loads(text)
+
+        articles = []
+        webenv = json_text["esearchresult"]["webenv"]
+        for uid in json_text["esearchresult"]["idlist"]:
+            article = self.retrieve_article(uid, webenv)
+            articles.append(article)
+
+        # Convert the list of articles to a JSON string
+        return articles
+
+    def retrieve_article(self, uid: str, webenv: str) -> dict:
+        url = (
+            self.base_url_efetch
+            + "db=pubmed&retmode=xml&id="
+            + uid
+            + "&webenv="
+            + webenv
+        )
+
+        retry = 0
+        while True:
+            try:
+                result = urllib.request.urlopen(url)
+                break
+            except urllib.error.HTTPError as e:
+                if e.code == 429 and retry < self.max_retry:
+                    # Too Many Requests error
+                    # wait for an exponentially increasing amount of time
+                    print(
+                        f"Too Many Requests, "
+                        f"waiting for {self.sleep_time:.2f} seconds..."
+                    )
+                    time.sleep(self.sleep_time)
+                    self.sleep_time *= 2
+                    retry += 1
+                else:
+                    raise e
+
+        xml_text = result.read().decode("utf-8")
+
+        # Get title
+        title = ""
+        if "<ArticleTitle>" in xml_text and "</ArticleTitle>" in xml_text:
+            start_tag = "<ArticleTitle>"
+            end_tag = "</ArticleTitle>"
+            title = xml_text[
+                xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
+            ]
+
+        # Get abstract
+        abstract = ""
+        if "<AbstractText>" in xml_text and "</AbstractText>" in xml_text:
+            start_tag = "<AbstractText>"
+            end_tag = "</AbstractText>"
+            abstract = xml_text[
+                xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
+            ]
+
+        # Get publication date
+        pub_date = ""
+        if "<PubDate>" in xml_text and "</PubDate>" in xml_text:
+            start_tag = "<PubDate>"
+            end_tag = "</PubDate>"
+            pub_date = xml_text[
+                xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
+            ]
+
+        # Return article as dictionary
+        article = {
+            "uid": uid,
+            "title": title,
+            "summary": abstract,
+            "pub_date": pub_date,
+        }
+        return article
+
+
+class PubmedQueryRun(BaseModel):
+    """Tool that searches the PubMed API."""
+
+    name = "PubMed"
+    description = (
+        "A wrapper around PubMed.org "
+        "Useful for when you need to answer questions about Physics, Mathematics, "
+        "Computer Science, Quantitative Biology, Quantitative Finance, Statistics, "
+        "Electrical Engineering, and Economics "
+        "from scientific articles on PubMed.org. "
+        "Input should be a search query."
+    )
+    api_wrapper: PubMedAPIWrapper = Field(default_factory=PubMedAPIWrapper)
+
+    def _run(
+        self,
+        query: str,
+    ) -> str:
+        """Use the Arxiv tool."""
+        return self.api_wrapper.run(query)
+
+
 class PubMedInput(BaseModel):
     query: str = Field(..., description="Search query.")


 class PubMedSearchTool(BuiltinTool):
     """
     Tool for performing a search using PubMed search engine.
@@ -34,7 +205,7 @@ class PubMedSearchTool(BuiltinTool):

         tool = PubmedQueryRun(args_schema=PubMedInput)

-        result = tool.run(query)
+        result = tool._run(query)

         return self.create_text_message(self.summary(user_id=user_id, content=result))

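A minimal sketch of the wrapper above (assumes network access to the NCBI eutils endpoints): `load` calls esearch for UIDs plus a WebEnv token, `retrieve_article` fetches each article's XML via efetch, backing off exponentially on HTTP 429 up to `max_retry` attempts.

# Illustrative query; output is "Published/Title/Summary" blocks or an error string.
pubmed = PubMedAPIWrapper(top_k_results=2)
print(pubmed.run("mRNA vaccine efficacy"))
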
@@ -1,11 +1,81 @@
-from typing import Any, Union
+from typing import Any, Optional, Union

-from langchain.utilities import TwilioAPIWrapper
+from pydantic import BaseModel, validator

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool


+class TwilioAPIWrapper(BaseModel):
+    """Messaging Client using Twilio.
+
+    To use, you should have the ``twilio`` python package installed,
+    and the environment variables ``TWILIO_ACCOUNT_SID``, ``TWILIO_AUTH_TOKEN``, and
+    ``TWILIO_FROM_NUMBER``, or pass `account_sid`, `auth_token`, and `from_number` as
+    named parameters to the constructor.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.utilities.twilio import TwilioAPIWrapper
+            twilio = TwilioAPIWrapper(
+                account_sid="ACxxx",
+                auth_token="xxx",
+                from_number="+10123456789"
+            )
+            twilio.run('test', '+12484345508')
+    """
+
+    client: Any  #: :meta private:
+    account_sid: Optional[str] = None
+    """Twilio account string identifier."""
+    auth_token: Optional[str] = None
+    """Twilio auth token."""
+    from_number: Optional[str] = None
+    """A Twilio phone number in [E.164](https://www.twilio.com/docs/glossary/what-e164)
+    format, an
+    [alphanumeric sender ID](https://www.twilio.com/docs/sms/send-messages#use-an-alphanumeric-sender-id),
+    or a [Channel Endpoint address](https://www.twilio.com/docs/sms/channels#channel-addresses)
+    that is enabled for the type of message you want to send. Phone numbers or
+    [short codes](https://www.twilio.com/docs/sms/api/short-code) purchased from
+    Twilio also work here. You cannot, for example, spoof messages from a private
+    cell phone number. If you are using `messaging_service_sid`, this parameter
+    must be empty.
+    """  # noqa: E501
+
+    @validator("client", pre=True, always=True)
+    def set_validator(cls, values: dict) -> dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            from twilio.rest import Client
+        except ImportError:
+            raise ImportError(
+                "Could not import twilio python package. "
+                "Please install it with `pip install twilio`."
+            )
+        account_sid = values.get("account_sid")
+        auth_token = values.get("auth_token")
+        values["from_number"] = values.get("from_number")
+        values["client"] = Client(account_sid, auth_token)
+
+        return values
+
+    def run(self, body: str, to: str) -> str:
+        """Run body through Twilio and respond with message sid.
+
+        Args:
+            body: The text of the message you want to send. Can be up to 1,600
+                characters in length.
+            to: The destination phone number in
+                [E.164](https://www.twilio.com/docs/glossary/what-e164) format for
+                SMS/MMS or
+                [Channel user address](https://www.twilio.com/docs/sms/channels#channel-addresses)
+                for other 3rd-party channels.
+        """  # noqa: E501
+        message = self.client.messages.create(to, from_=self.from_number, body=body)
+        return message.sid
+
+
 class SendMessageTool(BuiltinTool):
     """
     A tool for sending messages using Twilio API.

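Per the docstring example above, the intended call shape looks like this (credentials and numbers are placeholders — never commit real values):

# Placeholder credentials; mirrors the docstring example in the class above.
twilio = TwilioAPIWrapper(account_sid="ACxxx", auth_token="xxx", from_number="+15005550006")
sid = twilio.run("hello from dify", "+15005550001")  # returns the created message SID
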
@@ -1,16 +1,79 @@
-from typing import Any, Union
+from typing import Any, Optional, Union

-from langchain import WikipediaAPIWrapper
-from langchain.tools import WikipediaQueryRun
-from pydantic import BaseModel, Field
+import wikipedia

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool

+WIKIPEDIA_MAX_QUERY_LENGTH = 300

-class WikipediaInput(BaseModel):
-    query: str = Field(..., description="search query.")
+
+class WikipediaAPIWrapper:
+    """Wrapper around WikipediaAPI.
+
+    To use, you should have the ``wikipedia`` python package installed.
+    This wrapper will use the Wikipedia API to conduct searches and
+    fetch page summaries. By default, it will return the page summaries
+    of the top-k results.
+    It limits the Document content by doc_content_chars_max.
+    """
+
+    top_k_results: int = 3
+    lang: str = "en"
+    load_all_available_meta: bool = False
+    doc_content_chars_max: int = 4000
+
+    def __init__(self, doc_content_chars_max: int = 4000):
+        self.doc_content_chars_max = doc_content_chars_max
+
+    def run(self, query: str) -> str:
+        wikipedia.set_lang(self.lang)
+        wiki_client = wikipedia
+
+        """Run Wikipedia search and get page summaries."""
+        page_titles = wiki_client.search(query[:WIKIPEDIA_MAX_QUERY_LENGTH])
+        summaries = []
+        for page_title in page_titles[: self.top_k_results]:
+            if wiki_page := self._fetch_page(page_title):
+                if summary := self._formatted_page_summary(page_title, wiki_page):
+                    summaries.append(summary)
+        if not summaries:
+            return "No good Wikipedia Search Result was found"
+        return "\n\n".join(summaries)[: self.doc_content_chars_max]
+
+    @staticmethod
+    def _formatted_page_summary(page_title: str, wiki_page: Any) -> Optional[str]:
+        return f"Page: {page_title}\nSummary: {wiki_page.summary}"
+
+    def _fetch_page(self, page: str) -> Optional[str]:
+        try:
+            return wikipedia.page(title=page, auto_suggest=False)
+        except (
+            wikipedia.exceptions.PageError,
+            wikipedia.exceptions.DisambiguationError,
+        ):
+            return None
+
+
+class WikipediaQueryRun:
+    """Tool that searches the Wikipedia API."""
+
+    name = "Wikipedia"
+    description = (
+        "A wrapper around Wikipedia. "
+        "Useful for when you need to answer general questions about "
+        "people, places, companies, facts, historical events, or other subjects. "
+        "Input should be a search query."
+    )
+    api_wrapper: WikipediaAPIWrapper
+
+    def __init__(self, api_wrapper: WikipediaAPIWrapper):
+        self.api_wrapper = api_wrapper
+
+    def _run(
+        self,
+        query: str,
+    ) -> str:
+        """Use the Wikipedia tool."""
+        return self.api_wrapper.run(query)
+
+
 class WikiPediaSearchTool(BuiltinTool):
     def _invoke(self,
                 user_id: str,
@@ -24,14 +87,10 @@ class WikiPediaSearchTool(BuiltinTool):
             return self.create_text_message('Please input query')

         tool = WikipediaQueryRun(
-            name="wikipedia",
             api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
-            args_schema=WikipediaInput
         )

-        result = tool.run(tool_input={
-            'query': query
-        })
+        result = tool._run(query)

         return self.create_text_message(self.summary(user_id=user_id,content=result))

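A short sketch of the rewritten classes above (assumes the `wikipedia` package is installed; the query is an arbitrary example):

tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=1000))
print(tool._run("Alan Turing"))  # "Page: .../Summary: ..." blocks, or a not-found message
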
@@ -22,9 +22,6 @@ class ValueType(Enum):


 class VariablePool:
-    variables_mapping = {}
-    user_inputs: dict
-    system_variables: dict[SystemVariable, Any]

     def __init__(self, system_variables: dict[SystemVariable, Any],
                  user_inputs: dict) -> None:
@@ -34,6 +31,7 @@ class VariablePool:
         # 'query': 'abc',
         # 'files': []
         # }
+        self.variables_mapping = {}
         self.user_inputs = user_inputs
         self.system_variables = system_variables
         for system_variable, value in system_variables.items():

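The VariablePool hunks above fix a classic Python pitfall: `variables_mapping = {}` declared at class level is a single dict shared by every instance, so variables written into one pool leaked into all others; assigning it in `__init__` gives each pool its own dict. A self-contained demonstration of the pitfall (generic Python, independent of dify's classes):

class Shared:
    mapping = {}              # one dict shared by every instance

class PerInstance:
    def __init__(self):
        self.mapping = {}     # fresh dict per instance

a, b = Shared(), Shared()
a.mapping["k"] = 1
assert b.mapping == {"k": 1}  # state leaked across instances

c, d = PerInstance(), PerInstance()
c.mapping["k"] = 1
assert d.mapping == {}        # properly isolated
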
@@ -234,6 +234,9 @@ class CodeNode(BaseNode):
         parameters_validated = {}
         for output_name, output_config in output_schema.items():
+            dot = '.' if prefix else ''
+            if output_name not in result:
+                raise ValueError(f'Output {prefix}{dot}{output_name} is missing.')

             if output_config.type == 'object':
                 # check if output is object
                 if not isinstance(result.get(output_name), dict):

@@ -1,28 +1,21 @@
-import threading
 from typing import Any, cast

-from flask import Flask, current_app
-
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.agent_entities import PlanningStrategy
 from core.entities.model_entities import ModelStatus
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.rag.datasource.retrieval_service import RetrievalService
-from core.rerank.rerank import RerankRunner
+from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
-from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
-from core.workflow.nodes.knowledge_retrieval.multi_dataset_react_route import ReactMultiDatasetRouter
 from extensions.ext_database import db
-from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
+from models.dataset import Dataset, Document, DocumentSegment
 from models.workflow import WorkflowNodeExecutionStatus

 default_retrieval_model = {
@@ -106,10 +99,45 @@ class KnowledgeRetrievalNode(BaseNode):

                 available_datasets.append(dataset)
+        all_documents = []
+        dataset_retrieval = DatasetRetrieval()
         if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
-            all_documents = self._single_retrieve(available_datasets, node_data, query)
+            # fetch model config
+            model_instance, model_config = self._fetch_model_config(node_data)
+            # check model is support tool calling
+            model_type_instance = model_config.provider_model_bundle.model_type_instance
+            model_type_instance = cast(LargeLanguageModel, model_type_instance)
+            # get model schema
+            model_schema = model_type_instance.get_model_schema(
+                model=model_config.model,
+                credentials=model_config.credentials
+            )
+
+            if model_schema:
+                planning_strategy = PlanningStrategy.REACT_ROUTER
+                features = model_schema.features
+                if features:
+                    if ModelFeature.TOOL_CALL in features \
+                            or ModelFeature.MULTI_TOOL_CALL in features:
+                        planning_strategy = PlanningStrategy.ROUTER
+                all_documents = dataset_retrieval.single_retrieve(
+                    available_datasets=available_datasets,
+                    tenant_id=self.tenant_id,
+                    user_id=self.user_id,
+                    app_id=self.app_id,
+                    user_from=self.user_from.value,
+                    query=query,
+                    model_config=model_config,
+                    model_instance=model_instance,
+                    planning_strategy=planning_strategy
+                )
         elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
-            all_documents = self._multiple_retrieve(available_datasets, node_data, query)
+            all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id,
+                                                                self.user_from.value,
+                                                                available_datasets, query,
+                                                                node_data.multiple_retrieval_config.top_k,
+                                                                node_data.multiple_retrieval_config.score_threshold,
+                                                                node_data.multiple_retrieval_config.reranking_model.provider,
+                                                                node_data.multiple_retrieval_config.reranking_model.model)

         context_list = []
         if all_documents:
@@ -184,84 +212,6 @@ class KnowledgeRetrievalNode(BaseNode):
         variable_mapping['query'] = node_data.query_variable_selector
         return variable_mapping

-    def _single_retrieve(self, available_datasets, node_data, query):
-        tools = []
-        for dataset in available_datasets:
-            description = dataset.description
-            if not description:
-                description = 'useful for when you want to answer queries about the ' + dataset.name
-
-            description = description.replace('\n', '').replace('\r', '')
-            message_tool = PromptMessageTool(
-                name=dataset.id,
-                description=description,
-                parameters={
-                    "type": "object",
-                    "properties": {},
-                    "required": [],
-                }
-            )
-            tools.append(message_tool)
-        # fetch model config
-        model_instance, model_config = self._fetch_model_config(node_data)
-        # check model is support tool calling
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-        # get model schema
-        model_schema = model_type_instance.get_model_schema(
-            model=model_config.model,
-            credentials=model_config.credentials
-        )
-
-        if not model_schema:
-            return None
-        planning_strategy = PlanningStrategy.REACT_ROUTER
-        features = model_schema.features
-        if features:
-            if ModelFeature.TOOL_CALL in features \
-                    or ModelFeature.MULTI_TOOL_CALL in features:
-                planning_strategy = PlanningStrategy.ROUTER
-        dataset_id = None
-        if planning_strategy == PlanningStrategy.REACT_ROUTER:
-            react_multi_dataset_router = ReactMultiDatasetRouter()
-            dataset_id = react_multi_dataset_router.invoke(query, tools, node_data, model_config, model_instance,
-                                                           self.user_id, self.tenant_id)
-
-        elif planning_strategy == PlanningStrategy.ROUTER:
-            function_call_router = FunctionCallMultiDatasetRouter()
-            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
-        if dataset_id:
-            # get retrieval model config
-            dataset = db.session.query(Dataset).filter(
-                Dataset.id == dataset_id
-            ).first()
-            if dataset:
-                retrieval_model_config = dataset.retrieval_model \
-                    if dataset.retrieval_model else default_retrieval_model
-
-                # get top k
-                top_k = retrieval_model_config['top_k']
-                # get retrieval method
-                retrival_method = retrieval_model_config['search_method']
-                # get reranking model
-                reranking_model = retrieval_model_config['reranking_model'] \
-                    if retrieval_model_config['reranking_enable'] else None
-                # get score threshold
-                score_threshold = .0
-                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
-                if score_threshold_enabled:
-                    score_threshold = retrieval_model_config.get("score_threshold")
-
-                results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
-                                                    query=query,
-                                                    top_k=top_k, score_threshold=score_threshold,
-                                                    reranking_model=reranking_model)
-                self._on_query(query, [dataset_id])
-                if results:
-                    self._on_retrival_end(results)
-                return results
-        return []

     def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
             ModelInstance, ModelConfigWithCredentialsEntity]:
         """
@@ -332,112 +282,3 @@ class KnowledgeRetrievalNode(BaseNode):
             parameters=completion_params,
             stop=stop,
         )
-
-    def _multiple_retrieve(self, available_datasets, node_data, query):
-        threads = []
-        all_documents = []
-        dataset_ids = [dataset.id for dataset in available_datasets]
-        for dataset in available_datasets:
-            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'dataset_id': dataset.id,
-                'query': query,
-                'top_k': node_data.multiple_retrieval_config.top_k,
-                'all_documents': all_documents,
-            })
-            threads.append(retrieval_thread)
-            retrieval_thread.start()
-        for thread in threads:
-            thread.join()
-        # do rerank for searched documents
-        model_manager = ModelManager()
-        rerank_model_instance = model_manager.get_model_instance(
-            tenant_id=self.tenant_id,
-            provider=node_data.multiple_retrieval_config.reranking_model.provider,
-            model_type=ModelType.RERANK,
-            model=node_data.multiple_retrieval_config.reranking_model.model
-        )
-
-        rerank_runner = RerankRunner(rerank_model_instance)
-        all_documents = rerank_runner.run(query, all_documents,
-                                          node_data.multiple_retrieval_config.score_threshold,
-                                          node_data.multiple_retrieval_config.top_k)
-        self._on_query(query, dataset_ids)
-        if all_documents:
-            self._on_retrival_end(all_documents)
-        return all_documents
-
-    def _on_retrival_end(self, documents: list[Document]) -> None:
-        """Handle retrival end."""
-        for document in documents:
-            query = db.session.query(DocumentSegment).filter(
-                DocumentSegment.index_node_id == document.metadata['doc_id']
-            )
-
-            # if 'dataset_id' in document.metadata:
-            if 'dataset_id' in document.metadata:
-                query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
-
-            # add hit count to document segment
-            query.update(
-                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
-                synchronize_session=False
-            )
-
-            db.session.commit()
-
-    def _on_query(self, query: str, dataset_ids: list[str]) -> None:
-        """
-        Handle query.
-        """
-        if not query:
-            return
-        for dataset_id in dataset_ids:
-            dataset_query = DatasetQuery(
-                dataset_id=dataset_id,
-                content=query,
-                source='app',
-                source_app_id=self.app_id,
-                created_by_role=self.user_from.value,
-                created_by=self.user_id
-            )
-            db.session.add(dataset_query)
-        db.session.commit()
-
-    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
-        with flask_app.app_context():
-            dataset = db.session.query(Dataset).filter(
-                Dataset.tenant_id == self.tenant_id,
-                Dataset.id == dataset_id
-            ).first()
-
-            if not dataset:
-                return []
-
-            # get retrieval model , if the model is not setting , using default
-            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
-
-            if dataset.indexing_technique == "economy":
-                # use keyword table query
-                documents = RetrievalService.retrieve(retrival_method='keyword_search',
-                                                      dataset_id=dataset.id,
-                                                      query=query,
-                                                      top_k=top_k
-                                                      )
-                if documents:
-                    all_documents.extend(documents)
-            else:
-                if top_k > 0:
-                    # retrieval source
-                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
-                                                          dataset_id=dataset.id,
-                                                          query=query,
-                                                          top_k=top_k,
-                                                          score_threshold=retrieval_model['score_threshold']
-                                                          if retrieval_model['score_threshold_enabled'] else None,
-                                                          reranking_model=retrieval_model['reranking_model']
-                                                          if retrieval_model['reranking_enable'] else None
-                                                          )
-
-                    all_documents.extend(documents)

@@ -10,7 +10,7 @@ from core.file.file_obj import FileVar
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.llm_entities import LLMUsage
-from core.model_runtime.entities.message_entities import PromptMessage
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageContentType
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -434,6 +434,22 @@ class LLMNode(BaseNode):
             )
         stop = model_config.stop

+        vision_enabled = node_data.vision.enabled
+        for prompt_message in prompt_messages:
+            if not isinstance(prompt_message.content, str):
+                prompt_message_content = []
+                for content_item in prompt_message.content:
+                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
+                        prompt_message_content.append(content_item)
+                    elif content_item.type == PromptMessageContentType.TEXT:
+                        prompt_message_content.append(content_item)
+
+                if len(prompt_message_content) > 1:
+                    prompt_message.content = prompt_message_content
+                elif (len(prompt_message_content) == 1
+                      and prompt_message_content[0].type == PromptMessageContentType.TEXT):
+                    prompt_message.content = prompt_message_content[0].data
+
         return prompt_messages, stop

     @classmethod

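The block added above drops image parts from multi-part prompt messages when vision is disabled and collapses a lone remaining text part back to a plain string. A toy re-statement of that filtering logic with plain dicts (dify's real PromptMessageContentType is an enum; these stand-ins are assumptions for illustration):

# Stand-in message parts; dify uses typed content objects, not dicts.
parts = [{"type": "text", "data": "Describe the image:"},
         {"type": "image", "data": "https://example.com/cat.png"}]
vision_enabled = False

kept = [p for p in parts
        if p["type"] == "text" or (vision_enabled and p["type"] == "image")]
# Collapse a single remaining text part to a plain string, as the node does.
content = kept[0]["data"] if len(kept) == 1 and kept[0]["type"] == "text" else kept
print(content)  # -> "Describe the image:" (the image part was filtered out)
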
@@ -13,7 +13,7 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.llm.llm_node import LLMNode
 from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
@@ -65,7 +65,9 @@ class QuestionClassifierNode(LLMNode):
         categories = [_class.name for _class in node_data.classes]
         try:
             result_text_json = json.loads(result_text.strip('```JSON\n'))
-            categories = result_text_json.get('categories', [])
+            categories_result = result_text_json.get('categories', [])
+            if categories_result:
+                categories = categories_result
         except Exception:
             logging.error(f"Failed to parse result text: {result_text}")
         try:
@@ -89,14 +91,24 @@ class QuestionClassifierNode(LLMNode):
                 inputs=variables,
                 process_data=process_data,
                 outputs=outputs,
-                edge_source_handle=classes_map.get(categories[0], None)
+                edge_source_handle=classes_map.get(categories[0], None),
+                metadata={
+                    NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+                    NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
+                    NodeRunMetadataKey.CURRENCY: usage.currency
+                }
             )

         except ValueError as e:
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 inputs=variables,
-                error=str(e)
+                error=str(e),
+                metadata={
+                    NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+                    NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
+                    NodeRunMetadataKey.CURRENCY: usage.currency
+                }
             )

     @classmethod

@@ -1,9 +1,10 @@
 import logging
 import time
-from typing import Optional
+from typing import Optional, cast

+from core.app.app_config.entities import FileExtraConfig
 from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
-from core.file.file_obj import FileVar
+from core.file.file_obj import FileTransferMethod, FileType, FileVar
 from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool, VariableValue
@@ -16,6 +17,7 @@ from core.workflow.nodes.end.end_node import EndNode
 from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
 from core.workflow.nodes.if_else.if_else_node import IfElseNode
 from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.llm.entities import LLMNodeData
 from core.workflow.nodes.llm.llm_node import LLMNode
 from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
 from core.workflow.nodes.start.start_node import StartNode
@@ -219,7 +221,8 @@ class WorkflowEngineManager:
             raise ValueError('node id not found in workflow graph')

         # Get node class
-        node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type')))
+        node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
+        node_cls = node_classes.get(node_type)

         # init workflow run state
         node_instance = node_cls(
@@ -252,11 +255,40 @@ class WorkflowEngineManager:
             variable_node_id = variable_selector[0]
             variable_key_list = variable_selector[1:]

+            # get value
+            value = user_inputs.get(variable_key)
+
+            # temp fix for image type
+            if node_type == NodeType.LLM:
+                new_value = []
+                if isinstance(value, list):
+                    node_data = node_instance.node_data
+                    node_data = cast(LLMNodeData, node_data)
+
+                    detail = node_data.vision.configs.detail if node_data.vision.configs else None
+
+                    for item in value:
+                        if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
+                            transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
+                            file = FileVar(
+                                tenant_id=workflow.tenant_id,
+                                type=FileType.IMAGE,
+                                transfer_method=transfer_method,
+                                url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
+                                related_id=item.get(
+                                    'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
+                                extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
+                            )
+                            new_value.append(file)
+
+                    if new_value:
+                        value = new_value
+
             # append variable and value to variable pool
             variable_pool.append_variable(
                 node_id=variable_node_id,
                 variable_key_list=variable_key_list,
-                value=user_inputs.get(variable_key)
+                value=value
             )
         # run node
         node_run_result = node_instance.run(

@@ -28,6 +28,7 @@ def init_app(app: Flask) -> Celery:

     celery_app.conf.update(
         result_backend=app.config["CELERY_RESULT_BACKEND"],
+        broker_connection_retry_on_startup=True,
     )

     if app.config["BROKER_USE_SSL"]:

@@ -815,7 +815,7 @@ class Message(db.Model):
|
||||
@property
|
||||
def workflow_run(self):
|
||||
if self.workflow_run_id:
|
||||
from api.models.workflow import WorkflowRun
|
||||
from .workflow import WorkflowRun
|
||||
return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first()
|
||||
|
||||
return None
|
||||
|
||||
@@ -299,6 +299,10 @@ class WorkflowRun(db.Model):
|
||||
Message.workflow_run_id == self.id
|
||||
).first()
|
||||
|
||||
@property
|
||||
def workflow(self):
|
||||
return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
|
||||
|
||||
|
||||
class WorkflowNodeExecutionTriggeredFrom(Enum):
|
||||
"""
|
||||
|
||||
@@ -34,20 +34,27 @@ redis[hiredis]~=5.0.3
|
||||
openpyxl==3.1.2
|
||||
chardet~=5.1.0
|
||||
python-docx~=1.1.0
|
||||
pypdfium2==4.16.0
|
||||
pypdfium2~=4.17.0
|
||||
resend~=0.7.0
|
||||
pyjwt~=2.8.0
|
||||
anthropic~=0.23.1
|
||||
newspaper3k==0.2.8
|
||||
google-api-python-client==2.90.0
|
||||
wikipedia==1.4.0
|
||||
readabilipy==0.2.0
|
||||
google-ai-generativelanguage==0.6.1
|
||||
google-api-core==2.18.0
|
||||
google-api-python-client==2.90.0
|
||||
google-auth==2.29.0
|
||||
google-auth-httplib2==0.2.0
|
||||
google-generativeai==0.5.0
|
||||
google-search-results==2.4.2
|
||||
googleapis-common-protos==1.63.0
|
||||
replicate~=0.22.0
|
||||
websocket-client~=1.7.0
|
||||
dashscope[tokenizer]~=1.14.0
|
||||
huggingface_hub~=0.16.4
|
||||
transformers~=4.31.0
|
||||
transformers~=4.35.0
|
||||
tokenizers~=0.15.0
|
||||
pandas==1.5.3
|
||||
xinference-client==0.9.4
|
||||
safetensors==0.3.2
|
||||
@@ -55,13 +62,12 @@ zhipuai==1.0.7
|
||||
werkzeug~=3.0.1
|
||||
pymilvus==2.3.0
|
||||
qdrant-client==1.7.3
|
||||
cohere~=4.44
|
||||
cohere~=5.2.4
|
||||
pyyaml~=6.0.1
|
||||
numpy~=1.25.2
|
||||
unstructured[docx,pptx,msg,md,ppt]~=0.10.27
|
||||
bs4~=0.0.1
|
||||
markdown~=3.5.1
|
||||
google-generativeai~=0.3.2
|
||||
httpx[socks]~=0.24.1
|
||||
matplotlib~=3.8.2
|
||||
yfinance~=0.2.35
|
||||
@@ -75,4 +81,4 @@ twilio==9.0.0
|
||||
qrcode~=7.4.2
|
||||
azure-storage-blob==12.9.0
|
||||
azure-identity==1.15.0
|
||||
lxml==5.1.0
|
||||
lxml==5.1.0
|
||||
|
@@ -221,7 +221,8 @@ class AppService:
                 "name": app.name,
                 "mode": app.mode,
                 "icon": app.icon,
-                "icon_background": app.icon_background
+                "icon_background": app.icon_background,
+                "description": app.description
             }
         }

@@ -1,5 +1,8 @@
 from typing import Optional, Union

+from sqlalchemy import or_
+
+from core.app.entities.app_invoke_entities import InvokeFrom
 from core.llm_generator.llm_generator import LLMGenerator
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
@@ -13,8 +16,9 @@ class ConversationService:
     @classmethod
     def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
                               last_id: Optional[str], limit: int,
-                              include_ids: Optional[list] = None, exclude_ids: Optional[list] = None,
-                              exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
+                              invoke_from: InvokeFrom,
+                              include_ids: Optional[list] = None,
+                              exclude_ids: Optional[list] = None) -> InfiniteScrollPagination:
         if not user:
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)

@@ -24,6 +28,7 @@ class ConversationService:
             Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
             Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
             Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
+            or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value)
         )

         if include_ids is not None:
@@ -32,9 +37,6 @@ class ConversationService:
         if exclude_ids is not None:
             base_query = base_query.filter(~Conversation.id.in_(exclude_ids))

-        if exclude_debug_conversation:
-            base_query = base_query.filter(Conversation.override_model_configs == None)
-
         if last_id:
             last_conversation = base_query.filter(
                 Conversation.id == last_id,

Some files were not shown because too many files have changed in this diff.