Mirror of https://github.com/langgenius/dify.git (synced 2026-01-09 07:44:12 +00:00)
Compare commits
82 Commits
| SHA1 |
|---|
| e0f72d2791 |
| 3e51710fe6 |
| 7bfdca7a53 |
| 48d5628fd4 |
| c8fb619d37 |
| 57024614bd |
| a31b502668 |
| e58c3ac374 |
| 00f4e6ec44 |
| 6355e61eb8 |
| 27828f44b9 |
| 9525ca08b9 |
| 501caf0a69 |
| c17baef172 |
| 21ade71bad |
| 23e02d8eb0 |
| 86286e1ac8 |
| 7bbe12b2bd |
| e65a2a400d |
| 741079f317 |
| 0f5d4fd11b |
| 8eae206715 |
| 7434d44412 |
| 8394bbd47f |
| 14a2eeba0c |
| a18dde9b0d |
| 8438d820ad |
| e19ad023d2 |
| 0695f08f05 |
| 22ab4721e2 |
| 51f23c5dc2 |
| 1f48e3d44a |
| 0113627d7b |
| 0a5de0ff0b |
| 9c4bad8f1e |
| c7783dbd6c |
| ee9c7e204f |
| 483dcb6340 |
| 9ad7b65996 |
| ec1659cba0 |
| 09a8db10d4 |
| f3323beaca |
| 275973da8c |
| e2c89a9487 |
| 869690c485 |
| a3c7c07ecc |
| dc8a8af117 |
| 6c28e1e69a |
| 0e1163f698 |
| 8654415f33 |
| 1a6ad05a23 |
| 1d91535ba6 |
| 8799c888e3 |
| d7209d9057 |
| 5960103cb8 |
| 2ffea39a5c |
| 1e76b1bf2d |
| 2022ca1d52 |
| e1319d1a2d |
| a61df6cb03 |
| 790b885d0a |
| 1a2eacc5a6 |
| f7a2f7a727 |
| a4adca595a |
| c51e179db8 |
| b582fc13c3 |
| add33cb5e6 |
| 83105d0d8f |
| 7b0818b8e5 |
| 28cd3a8c9f |
| 0355645a0e |
| cb7a608d75 |
| bdb0d77227 |
| 149102927b |
| d8c0d722d2 |
| cb7be3767c |
| 34bf2877c8 |
| 3ebec8fa41 |
| f877d19c6a |
| a63a9c7d45 |
| 1779cea6e3 |
| 26eff330f9 |
.github/ISSUE_TEMPLATE/feature_request.yml (vendored, 25 lines changed)
@@ -14,22 +14,35 @@ body:
       required: true
   - type: textarea
     attributes:
-      label: Description of the new feature / enhancement
-      placeholder: What is the expected behavior of the proposed feature?
+      label: 1. Is this request related to a challenge you're experiencing?
+      placeholder: Please describe the specific scenario or problem you're facing as clearly as possible. For instance "I was trying to use [feature] for [specific task], and [what happened]... It was frustrating because...."
     validations:
       required: true
   - type: textarea
     attributes:
-      label: Scenario when this would be used?
-      placeholder: What is the scenario this would be used? Why is this important to your workflow as a dify user?
+      label: 2. Describe the feature you'd like to see
+      placeholder: Think about what you want to achieve and how this feature will help you. Sketches, flow diagrams, or any visual representation will be a major plus.
     validations:
       required: true
   - type: textarea
     attributes:
-      label: Supporting information
-      placeholder: "Having additional evidence, data, tweets, blog posts, research, ... anything is extremely helpful. This information provides context to the scenario that may otherwise be lost."
+      label: 3. How will this feature improve your workflow or experience?
+      placeholder: Tell us how this change will benefit your work. This helps us prioritize based on user impact.
     validations:
       required: true
+  - type: textarea
+    attributes:
+      label: 4. Additional context or comments
+      placeholder: (Any other information, comments, documentations, links, or screenshots that would provide more clarity. This is the place to add anything else not covered above.)
+    validations:
+      required: false
+  - type: checkboxes
+    attributes:
+      label: 5. Can you help us with this feature?
+      description: Let us know! This is not a commitment, but a starting point for collaboration.
+      options:
+        - label: I am interested in contributing to this feature.
+          required: false
   - type: markdown
     attributes:
       value: Please limit one request per issue.
@@ -4,10 +4,6 @@ on:
   pull_request:
     branches:
       - main
-  push:
-    branches:
-      - deploy/dev
-      - feat/model-runtime
 
 jobs:
   test:
.github/workflows/style.yml (vendored, 5 lines changed)
@@ -4,9 +4,6 @@ on:
   pull_request:
     branches:
       - main
-  push:
-    branches:
-      - deploy/dev
 
 concurrency:
   group: dep-${{ github.head_ref || github.run_id }}

@@ -24,7 +21,7 @@ jobs:
       - name: Setup NodeJS
         uses: actions/setup-node@v4
         with:
-          node-version: 18
+          node-version: 20
           cache: yarn
           cache-dependency-path: ./web/package.json
.github/workflows/tool-tests.yaml (vendored, new file, 26 lines)
@@ -0,0 +1,26 @@
name: Run Tool Pytest

on:
  pull_request:
    branches:
      - main

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
          cache: 'pip'
          cache-dependency-path: ./api/requirements.txt

      - name: Install dependencies
        run: pip install -r ./api/requirements.txt

      - name: Run pytest
        run: pytest ./api/tests/integration_tests/tools/test_all_provider.py
CONTRIBUTING.md (172 lines changed)
@@ -1,66 +1,158 @@
 # Contributing
 
+So you're looking to contribute to Dify - that's awesome, we can't wait to see what you do. As a startup with limited headcount and funding, we have grand ambitions to design the most intuitive workflow for building and managing LLM applications. Any help from the community counts, truly.
 
-Thanks for your interest in [Dify](https://dify.ai) and for wanting to contribute! Before you begin, read the
-[code of conduct](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md) and check out the
-[existing issues](https://github.com/langgenius/langgenius-gateway/issues).
-This document describes how to set up your development environment to build and test [Dify](https://dify.ai).
+We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part.
 
-### Install dependencies
+This guide, like Dify itself, is a constant work in progress. We highly appreciate your understanding if at times it lags behind the actual project, and welcome any feedback for us to improve.
 
-You need to install and configure the following dependencies on your machine to build [Dify](https://dify.ai):
+In terms of licensing, please take a minute to read our short [License and Contributor Agreement](./license). The community also adheres to the [code of conduct](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md).
+
+## Before you jump in
+
+[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
+
+### Feature requests:
+
+* If you're opening a new feature request, we'd like you to explain what the proposed feature achieves, and include as much context as possible. [@perzeusss](https://github.com/perzeuss) has made a solid [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) that helps you draft out your needs. Feel free to give it a try.
+
+* If you want to pick one up from the existing issues, simply drop a comment below it saying so.
+
+  A team member working in the related direction will be looped in. If all looks good, they will give the go-ahead for you to start coding. We ask that you hold off working on the feature until then, so none of your work goes to waste should we propose changes.
+
+  Depending on whichever area the proposed feature falls under, you might talk to different team members. Here's a rundown of the areas each of our team members is working on at the moment:
+
+  | Member | Scope |
+  | ------ | ----- |
+  | [@yeuoly](https://github.com/Yeuoly) | Architecting Agents |
+  | [@jyong](https://github.com/JohnJyong) | RAG pipeline design |
+  | [@GarfieldDai](https://github.com/GarfieldDai) | Building workflow orchestrations |
+  | [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | Making our frontend a breeze to use |
+  | [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | Developer experience, points of contact for anything |
+  | [@takatost](https://github.com/takatost) | Overall product direction and architecture |
+
+  How we prioritize:
+
+  | Feature Type | Priority |
+  | ------------ | -------- |
+  | High-priority features as labeled by a team member | High Priority |
+  | Popular feature requests from our [community feedback board](https://feedback.dify.ai/) | Medium Priority |
+  | Non-core features and minor enhancements | Low Priority |
+  | Valuable but not immediate | Future-Feature |
+
+### Anything else (e.g. bug report, performance optimization, typo correction):
+
+* Start coding right away.
+
+  How we prioritize:
+
+  | Issue Type | Priority |
+  | ---------- | -------- |
+  | Bugs in core functions (cannot login, applications not working, security loopholes) | Critical |
+  | Non-critical bugs, performance boosts | Medium Priority |
+  | Minor fixes (typos, confusing but working UI) | Low Priority |
+
+## Installing
+
+Here are the steps to set up Dify for development:
+
+### 1. Fork this repository
+
+### 2. Clone the repo
+
+Clone the forked repository from your terminal:
+
+```
+git clone git@github.com:<github_username>/dify.git
+```
+
+### 3. Verify dependencies
+
+Dify requires the following dependencies to build, make sure they're installed on your system:
 
 - [Git](http://git-scm.com/)
 - [Docker](https://www.docker.com/)
 - [Docker Compose](https://docs.docker.com/compose/install/)
 - [Node.js v18.x (LTS)](http://nodejs.org)
 - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
 - [Python](https://www.python.org/) version 3.10.x
 
-## Local development
+### 4. Installations
 
-To set up a working development environment, just fork the project git repository and install the backend and frontend dependencies using the proper package manager and create run the docker-compose stack.
+Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.
 
-### Fork the repository
+Check the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for a list of common issues and steps to troubleshoot.
 
-you need to fork the [repository](https://github.com/langgenius/dify).
+### 5. Visit dify in your browser
 
-### Clone the repo
-
-Clone your GitHub forked repository:
-
-```
-git clone git@github.com:<github_username>/dify.git
-```
+To validate your set up, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running.
 
+## Developing
+
+If you are adding a model provider, [this guide](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md) is for you.
+
+If you are adding a tool provider to Agent or Workflow, [this guide](./api/core/tools/README.md) is for you.
+
+To help you quickly navigate where your contribution fits, a brief, annotated outline of Dify's backend & frontend is as follows:
+
+### Backend
+
+Dify’s backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login.
+
+```
+[api/]
+├── constants             // Constant settings used throughout code base.
+├── controllers           // API route definitions and request handling logic.
+├── core                  // Core application orchestration, model integrations, and tools.
+├── docker                // Docker & containerization related configurations.
+├── events                // Event handling and processing
+├── extensions            // Extensions with 3rd party frameworks/platforms.
+├── fields                // field definitions for serialization/marshalling.
+├── libs                  // Reusable libraries and helpers.
+├── migrations            // Scripts for database migration.
+├── models                // Database models & schema definitions.
+├── services              // Specifies business logic.
+├── storage               // Private key storage.
+├── tasks                 // Handling of async tasks and background jobs.
+└── tests
+```
 
-### Install backend
+### Frontend
 
-To learn how to install the backend application, please refer to the [Backend README](api/README.md).
+The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typescript and uses [Tailwind CSS](https://tailwindcss.com/) for styling. [React-i18next](https://react.i18next.com/) is used for internationalization.
 
-### Install frontend
+```
+[web/]
+├── app                   // layouts, pages, and components
+│   ├── (commonLayout)    // common layout used throughout the app
+│   ├── (shareLayout)     // layouts specifically shared across token-specific sessions
+│   ├── activate          // activate page
+│   ├── components        // shared by pages and layouts
+│   ├── install           // install page
+│   ├── signin            // signin page
+│   └── styles            // globally shared styles
+├── assets                // Static assets
+├── bin                   // scripts ran at build step
+├── config                // adjustable settings and options
+├── context               // shared contexts used by different portions of the app
+├── dictionaries          // Language-specific translate files
+├── docker                // container configurations
+├── hooks                 // Reusable hooks
+├── i18n                  // Internationalization configuration
+├── models                // describes data models & shapes of API responses
+├── public                // meta assets like favicon
+├── service               // specifies shapes of API actions
+├── test
+├── types                 // descriptions of function params and return values
+└── utils                 // Shared utility functions
+```
 
-To learn how to install the frontend application, please refer to the [Frontend README](web/README.md).
+## Submitting your PR
 
-### Visit dify in your browser
+At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).
 
-Finally, you can now visit [http://localhost:3000](http://localhost:3000) to view the [Dify](https://dify.ai) in local environment.
+And that's it! Once your PR is merged, you will be featured as a contributor in our [README](https://github.com/langgenius/dify/blob/main/README.md).
 
-## Create a pull request
-
-After making your changes, open a pull request (PR). Once you submit your pull request, others from the Dify team/community will review it with you.
-
-Did you have an issue, like a merge conflict, or don't know how to open a pull request? Check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests) on how to resolve merge conflicts and other issues. Once your PR has been merged, you will be proudly listed as a contributor in the [contributor chart](https://github.com/langgenius/langgenius-gateway/graphs/contributors).
-
-## Community channels
-
-Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!
-
-### Provider Integrations
-
-If you see a model provider not yet supported by Dify that you'd like to use, follow these [steps](api/core/model_runtime/README.md) to submit a PR.
-
-### i18n (Internationalization) Support
-
-We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.
-Also check out the [Frontend i18n README](web/i18n/README_EN.md) for more information.
+## Getting Help
+
+If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/AhzKf7dNgk) for a quick chat.
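The backend description in the diff above names the whole stack in one breath: Flask for HTTP, SQLAlchemy for ORM, Celery for task queueing, Flask-login for authorization. For orientation, here is a minimal, hypothetical sketch of how those four pieces typically wire together; every name in it is illustrative, and none of it is Dify's actual code:

```python
# Hypothetical miniature of the stack the CONTRIBUTING text describes.
from celery import Celery
from flask import Flask
from flask_login import LoginManager, UserMixin
from flask_sqlalchemy import SQLAlchemy

db = SQLAlchemy()                    # ORM extension, bound to an app later
login_manager = LoginManager()       # session-based auth extension
celery = Celery(__name__, broker="redis://localhost:6379/0")  # task queue


class Account(db.Model, UserMixin):  # an ORM model doubling as the login user
    id = db.Column(db.Integer, primary_key=True)
    email = db.Column(db.String(255), unique=True, nullable=False)


@login_manager.user_loader
def load_user(user_id: str):
    # Flask-Login calls this to rehydrate the user from the session cookie.
    return db.session.get(Account, int(user_id))


@celery.task
def send_welcome_email(account_id: int) -> None:
    # Background jobs run out-of-process in a Celery worker.
    print(f"queued welcome email for account {account_id}")


def create_app() -> Flask:
    # App-factory pattern: extensions are module-level and bound late,
    # one common reason projects keep an extensions/ directory like the
    # one in the backend tree above.
    app = Flask(__name__)
    app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///dev.db"
    app.config["SECRET_KEY"] = "change-me"
    db.init_app(app)
    login_manager.init_app(app)
    return app
```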
README.md (52 lines changed)
@@ -26,23 +26,25 @@
 
 ![](./images/describe-en.png)
 
-## Use Cloud Services
-
-[Dify.AI Cloud](https://dify.ai) provides all the capabilities of the open-source version, and includes 200 free requests to OpenAI GPT-3.5.
-
-## Why Dify
+## Using our Cloud Services
 
-Dify is model-agnostic and boasts a comprehensive tech stack compared to hardcoded development libraries like LangChain. Unlike OpenAI's Assistants API, Dify allows for full local deployment of services.
+You can try out [Dify.AI Cloud](https://dify.ai) now. It provides all the capabilities of the self-deployed version, and includes 200 free requests to OpenAI GPT-3.5.
 
+## Dify vs. LangChain vs. Assistants API
+
 | Feature | Dify.AI | Assistants API | LangChain |
 |---------|---------|----------------|-----------|
 | **Programming Approach** | API-oriented | API-oriented | Python Code-oriented |
-| **Ecosystem Strategy** | Open Source | Closed and Commercial | Open Source |
+| **Ecosystem Strategy** | Open Source | Closed Source | Open Source |
 | **RAG Engine** | Supported | Supported | Not Supported |
 | **Prompt IDE** | Included | Included | None |
-| **Supported LLMs** | Rich Variety | Only GPT | Rich Variety |
+| **Supported LLMs** | Rich Variety | OpenAI-only | Rich Variety |
 | **Local Deployment** | Supported | Not Supported | Not Applicable |
 
 ## Features
 
 ![](./images/models.png)

@@ -59,7 +61,7 @@ Dify is model-agnostic and boasts a comprehensive tech stack compared to hardcod
 
 ## Before You Start
 
-**Star us, and you'll get instant notifications for all new releases on GitHub!**
+**Star us on GitHub, and be instantly notified for new releases!**
 
 ![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc40886d5c9)

@@ -103,17 +105,39 @@ If you need to customize the configuration, please refer to the comments in our
 
 [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)
 
+## Contributing
+
+For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
+
+At the same time, please consider supporting Dify by sharing it on social media and at events and conferences.
+
+### Contributors
+
+<a href="https://github.com/langgenius/dify/graphs/contributors">
+  <img src="https://contrib.rocks/image?repo=langgenius/dify" />
+</a>
+
+### Translations
+
+We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README_EN.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/AhzKf7dNgk).
+
 ## Community & Support
 
-We welcome you to contribute to Dify to help make Dify better in various ways, submitting code, issues, new ideas, or sharing the interesting and useful AI applications you have created based on Dify. At the same time, we also welcome you to share Dify at different events, conferences, and social media.
-
-* [Canny](https://feedback.dify.ai/). Best for: sharing feedback and checking out our feature roadmap.
-* [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
-* [Email Support](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify). Best for: questions you have about using Dify.AI.
-* [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
-* [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
-* [Business Contact](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry). Best for: business inquiries of licensing Dify.AI for commercial use.
+- [Roadmap and Feedback](https://feedback.dify.ai/). Best for: sharing feedback and checking out our feature roadmap.
+- [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs and errors you encounter using Dify.AI, see the [Contribution Guide](CONTRIBUTING.md).
+- [Email Support](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify). Best for: questions you have about using Dify.AI.
+- [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
+- [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
+- [Business License](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry). Best for: business inquiries of licensing Dify.AI for commercial use.
+
+### Direct Meetings
+
+**Help us make Dify better. Reach out directly to us**.
+
+| Point of Contact | Purpose |
+| :--: | :--: |
+| <a href='https://cal.com/guchenhe/15min' target='_blank'><img src='https://i.postimg.cc/fWBqSmjP/Git-Hub-README-Button-3x.png' border='0' alt='Git-Hub-README-Button-3x' height="60" width="214"/></a> | Product design feedback, user experience discussions, feature planning and roadmaps. |
+| <a href='https://cal.com/pinkbanana' target='_blank'><img src='https://i.postimg.cc/LsRTh87D/Git-Hub-README-Button-2x.png' border='0' alt='Git-Hub-README-Button-2x' height="60" width="225"/></a> | Technical support, issues, or feature requests |
 
 ## Security Disclosure
@@ -15,7 +15,6 @@ CONSOLE_WEB_URL=http://127.0.0.1:3000
SERVICE_API_URL=http://127.0.0.1:5001

# Web APP base URL
APP_API_URL=http://127.0.0.1:5001
APP_WEB_URL=http://127.0.0.1:3000

# Files URL
@@ -102,10 +101,10 @@ NOTION_CLIENT_ID=you-client-id
NOTION_INTERNAL_SECRET=you-internal-secret

# Hosted Model Credentials
HOSTED_OPENAI_ENABLED=false
HOSTED_OPENAI_API_KEY=
HOSTED_OPENAI_API_BASE=
HOSTED_OPENAI_API_ORGANIZATION=
HOSTED_OPENAI_TRIAL_ENABLED=false
HOSTED_OPENAI_QUOTA_LIMIT=200
HOSTED_OPENAI_PAID_ENABLED=false

@@ -114,9 +113,9 @@ HOSTED_AZURE_OPENAI_API_KEY=
HOSTED_AZURE_OPENAI_API_BASE=
HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200

HOSTED_ANTHROPIC_ENABLED=false
HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
HOSTED_ANTHROPIC_TRIAL_ENABLED=false
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
HOSTED_ANTHROPIC_PAID_ENABLED=false
@@ -13,21 +13,24 @@ RUN pip install --prefix=/pkg -r requirements.txt
 # build stage
 FROM python:3.10-slim AS builder
 
 ENV FLASK_APP app.py
 ENV EDITION SELF_HOSTED
 ENV DEPLOY_ENV PRODUCTION
 ENV CONSOLE_API_URL http://127.0.0.1:5001
 ENV CONSOLE_WEB_URL http://127.0.0.1:3000
 ENV SERVICE_API_URL http://127.0.0.1:5001
 ENV APP_API_URL http://127.0.0.1:5001
 ENV APP_WEB_URL http://127.0.0.1:3000
 
 EXPOSE 5001
 
 # set timezone
 ENV TZ UTC
 
 WORKDIR /app/api
 
 RUN apt-get update \
-    && apt-get install -y --no-install-recommends bash curl wget vim nodejs \
+    && apt-get install -y --no-install-recommends bash curl wget vim nodejs ffmpeg \
     && apt-get autoremove \
     && rm -rf /var/lib/apt/lists/*
@@ -30,7 +30,7 @@ from flask import Flask, Response, request
 from flask_cors import CORS
 from libs.passport import PassportService
 # DO NOT REMOVE BELOW
-from models import account, dataset, model, source, task, tool, web
+from models import account, dataset, model, source, task, tool, web, tools
 from services.account_service import AccountService
 # DO NOT REMOVE ABOVE

@@ -124,6 +124,7 @@ def load_user_from_request(request_from_flask_login):
     else:
         return None
 
 
 @login_manager.unauthorized_handler
 def unauthorized_handler():
     """Handle unauthorized requests."""
@@ -11,6 +11,7 @@ import uuid
 
 import click
 import qdrant_client
+from constants.languages import user_input_form_template
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.index import IndexBuilder
 from core.model_manager import ModelManager

@@ -22,7 +23,7 @@ from libs.password import hash_password, password_pattern, valid_password
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant, TenantAccountJoin
 from models.dataset import Dataset, DatasetCollectionBinding, DatasetQuery, Document
-from models.model import Account, App, AppModelConfig, Message, MessageAnnotation
+from models.model import Account, App, AppModelConfig, Message, MessageAnnotation, InstalledApp
 from models.provider import Provider, ProviderModel, ProviderQuotaType, ProviderType
 from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
 from tqdm import tqdm

@@ -583,28 +584,6 @@ def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count:
 @click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
 def update_app_model_configs(batch_size):
     pre_prompt_template = '{{default_input}}'
-    user_input_form_template = {
-        "en-US": [
-            {
-                "paragraph": {
-                    "label": "Query",
-                    "variable": "default_input",
-                    "required": False,
-                    "default": ""
-                }
-            }
-        ],
-        "zh-Hans": [
-            {
-                "paragraph": {
-                    "label": "查询内容",
-                    "variable": "default_input",
-                    "required": False,
-                    "default": ""
-                }
-            }
-        ]
-    }
 
     click.secho("Start migrate old data that the text generator can support paragraph variable.", fg='green')
@@ -22,7 +22,6 @@ DEFAULTS = {
    'CONSOLE_API_URL': 'https://cloud.dify.ai',
    'SERVICE_API_URL': 'https://api.dify.ai',
    'APP_WEB_URL': 'https://udify.app',
    'APP_API_URL': 'https://udify.app',
    'FILES_URL': '',
    'STORAGE_TYPE': 'local',
    'STORAGE_LOCAL_PATH': 'storage',

@@ -39,13 +38,19 @@ DEFAULTS = {
    'CELERY_BACKEND': 'database',
    'LOG_LEVEL': 'INFO',
    'HOSTED_OPENAI_QUOTA_LIMIT': 200,
    'HOSTED_OPENAI_ENABLED': 'False',
    'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
    'HOSTED_OPENAI_PAID_ENABLED': 'False',
    'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
    'HOSTED_OPENAI_PAID_MIN_QUANTITY': 1,
    'HOSTED_OPENAI_PAID_MAX_QUANTITY': 1,
    'HOSTED_AZURE_OPENAI_ENABLED': 'False',
    'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
    'HOSTED_ANTHROPIC_ENABLED': 'False',
    'HOSTED_ANTHROPIC_TRIAL_ENABLED': 'False',
    'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
    'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
    'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 1,
    'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 1,
    'HOSTED_MODERATION_ENABLED': 'False',
    'HOSTED_MODERATION_PROVIDERS': '',
    'CLEAN_DAY_SETTING': 30,
@@ -66,7 +71,8 @@ def get_env(key):
 
 
 def get_bool_env(key):
-    return get_env(key).lower() == 'true'
+    value = get_env(key)
+    return value.lower() == 'true' if value is not None else False
 
 
 def get_cors_allow_origins(env, default):
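The one-line change above is a None-safety fix: get_env can return None for keys absent from both the environment and DEFAULTS, and the old body then raised AttributeError on .lower(). A self-contained sketch of the before/after behavior (the DEFAULTS entry here is illustrative):

```python
import os

DEFAULTS = {'HOSTED_OPENAI_ENABLED': 'False'}  # illustrative entry


def get_env(key):
    # Mirrors the pattern in the file: environment first, DEFAULTS second,
    # None when the key is unknown to both.
    return os.environ.get(key, DEFAULTS.get(key))


def get_bool_env(key):
    # Old body crashed on unknown keys: get_env(key).lower() raises
    # AttributeError when get_env returns None.
    value = get_env(key)
    return value.lower() == 'true' if value is not None else False


assert get_bool_env('HOSTED_OPENAI_ENABLED') is False
assert get_bool_env('SOME_UNSET_FLAG') is False  # no crash on missing keys
```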
@@ -87,7 +93,7 @@ class Config:
        # ------------------------
        # General Configurations.
        # ------------------------
-       self.CURRENT_VERSION = "0.4.6"
+       self.CURRENT_VERSION = "0.5.0"
        self.COMMIT_SHA = get_env('COMMIT_SHA')
        self.EDITION = "SELF_HOSTED"
        self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -96,35 +102,25 @@ class Config:
 
        # The backend URL prefix of the console API.
        # used to concatenate the login authorization callback or notion integration callback.
-       self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
+       self.CONSOLE_API_URL = get_env('CONSOLE_API_URL')
 
        # The front-end URL prefix of the console web.
        # used to concatenate some front-end addresses and for CORS configuration use.
-       self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
-
-       # WebApp API backend Url prefix.
-       # used to declare the back-end URL for the front-end API.
-       self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
+       self.CONSOLE_WEB_URL = get_env('CONSOLE_WEB_URL')
 
        # WebApp Url prefix.
        # used to display WebAPP API Base Url to the front-end.
-       self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
+       self.APP_WEB_URL = get_env('APP_WEB_URL')
 
        # Service API Url prefix.
        # used to display Service API Base Url to the front-end.
-       self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
+       self.SERVICE_API_URL = get_env('SERVICE_API_URL')
 
        # File preview or download Url prefix.
        # used to display File preview or download Url to the front-end or as Multi-model inputs;
        # Url is signed and has expiration time.
        self.FILES_URL = get_env('FILES_URL') if get_env('FILES_URL') else self.CONSOLE_API_URL
 
+       # Fallback Url prefix.
+       # Will be deprecated in the future.
+       self.CONSOLE_URL = get_env('CONSOLE_URL')
+       self.API_URL = get_env('API_URL')
+       self.APP_URL = get_env('APP_URL')
 
        # Your App secret key will be used for securely signing the session cookie
        # Make sure you are changing this key for your deployment with a strong key.
        # You can generate a strong key using `openssl rand -base64 42`.
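Taken together, this hunk replaces per-setting umbrella fallbacks (CONSOLE_URL, API_URL, APP_URL) with direct reads, keeping the umbrella variables only as deprecated attributes. A small, hypothetical sketch of the two resolution styles side by side (the variable values are made up):

```python
import os

# Old style: a single umbrella variable silently overrode the specific one.
def resolve_old(umbrella: str, specific: str) -> str | None:
    return os.environ.get(umbrella) or os.environ.get(specific)

# New style: one variable per setting, no hidden override.
def resolve_new(specific: str) -> str | None:
    return os.environ.get(specific)

os.environ['CONSOLE_URL'] = 'https://legacy.example.com'    # hypothetical
os.environ['CONSOLE_API_URL'] = 'https://api.example.com'   # hypothetical

assert resolve_old('CONSOLE_URL', 'CONSOLE_API_URL') == 'https://legacy.example.com'
assert resolve_new('CONSOLE_API_URL') == 'https://api.example.com'
```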
@@ -260,23 +256,35 @@ class Config:
        # ------------------------
        # Platform Configurations.
        # ------------------------
        self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
        self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
        self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
        self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
        self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
        self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
        self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
        self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
        self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
        self.HOSTED_OPENAI_PAID_MIN_QUANTITY = int(get_env('HOSTED_OPENAI_PAID_MIN_QUANTITY'))
        self.HOSTED_OPENAI_PAID_MAX_QUANTITY = int(get_env('HOSTED_OPENAI_PAID_MAX_QUANTITY'))

        self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
        self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
        self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
        self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT'))

        self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
        self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
        self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
        self.HOSTED_ANTHROPIC_TRIAL_ENABLED = get_bool_env('HOSTED_ANTHROPIC_TRIAL_ENABLED')
        self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
        self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
        self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
        self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
        self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
        self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))

        self.HOSTED_MINIMAX_ENABLED = get_bool_env('HOSTED_MINIMAX_ENABLED')
        self.HOSTED_SPARK_ENABLED = get_bool_env('HOSTED_SPARK_ENABLED')
        self.HOSTED_ZHIPUAI_ENABLED = get_bool_env('HOSTED_ZHIPUAI_ENABLED')

        self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
        self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
api/constants/languages.py (new file, 326 lines)
@@ -0,0 +1,326 @@
import json

from models.model import AppModelConfig

languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT']

language_timezone_mapping = {
    'en-US': 'America/New_York',
    'zh-Hans': 'Asia/Shanghai',
    'pt-BR': 'America/Sao_Paulo',
    'es-ES': 'Europe/Madrid',
    'fr-FR': 'Europe/Paris',
    'de-DE': 'Europe/Berlin',
    'ja-JP': 'Asia/Tokyo',
    'ko-KR': 'Asia/Seoul',
    'ru-RU': 'Europe/Moscow',
    'it-IT': 'Europe/Rome',
}


def supported_language(lang):
    if lang in languages:
        return lang

    error = ('{lang} is not a valid language.'
             .format(lang=lang))
    raise ValueError(error)


user_input_form_template = {
    "en-US": [
        {
            "paragraph": {
                "label": "Query",
                "variable": "default_input",
                "required": False,
                "default": ""
            }
        }
    ],
    "zh-Hans": [
        {
            "paragraph": {
                "label": "查询内容",
                "variable": "default_input",
                "required": False,
                "default": ""
            }
        }
    ],
    "pt-BR": [
        {
            "paragraph": {
                "label": "Consulta",
                "variable": "default_input",
                "required": False,
                "default": ""
            }
        }
    ],
    "es-ES": [
        {
            "paragraph": {
                "label": "Consulta",
                "variable": "default_input",
                "required": False,
                "default": ""
            }
        }
    ],
}

demo_model_templates = {
    'en-US': [
        {
            'name': 'Translation Assistant',
            'icon': '',
            'icon_background': '',
            'description': 'A multilingual translator that provides translation capabilities in multiple languages, translating user input into the language they need.',
            'mode': 'completion',
            'model_config': AppModelConfig(
                provider='openai',
                model_id='gpt-3.5-turbo-instruct',
                configs={
                    'prompt_template': "Please translate the following text into {{target_language}}:\n",
                    'prompt_variables': [
                        {
                            "key": "target_language",
                            "name": "Target Language",
                            "description": "The language you want to translate into.",
                            "type": "select",
                            "default": "Chinese",
                            'options': [
                                'Chinese',
                                'English',
                                'Japanese',
                                'French',
                                'Russian',
                                'German',
                                'Spanish',
                                'Korean',
                                'Italian',
                            ]
                        }
                    ],
                    'completion_params': {
                        'max_token': 1000,
                        'temperature': 0,
                        'top_p': 0,
                        'presence_penalty': 0.1,
                        'frequency_penalty': 0.1,
                    }
                },
                opening_statement='',
                suggested_questions=None,
                pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
                model=json.dumps({
                    "provider": "openai",
                    "name": "gpt-3.5-turbo-instruct",
                    "mode": "completion",
                    "completion_params": {
                        "max_tokens": 1000,
                        "temperature": 0,
                        "top_p": 0,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1
                    }
                }),
                user_input_form=json.dumps([
                    {
                        "select": {
                            "label": "Target Language",
                            "variable": "target_language",
                            "description": "The language you want to translate into.",
                            "default": "Chinese",
                            "required": True,
                            'options': [
                                'Chinese',
                                'English',
                                'Japanese',
                                'French',
                                'Russian',
                                'German',
                                'Spanish',
                                'Korean',
                                'Italian',
                            ]
                        }
                    },{
                        "paragraph": {
                            "label": "Query",
                            "variable": "query",
                            "required": True,
                            "default": ""
                        }
                    }
                ])
            )
        },
        {
            'name': 'AI Front-end Interviewer',
            'icon': '',
            'icon_background': '',
            'description': 'A simulated front-end interviewer that tests the skill level of front-end development through questioning.',
            'mode': 'chat',
            'model_config': AppModelConfig(
                provider='openai',
                model_id='gpt-3.5-turbo',
                configs={
                    'introduction': 'Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. ',
                    'prompt_template': "You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? After the user answers this question, please express your understanding and support.\n",
                    'prompt_variables': [],
                    'completion_params': {
                        'max_token': 300,
                        'temperature': 0.8,
                        'top_p': 0.9,
                        'presence_penalty': 0.1,
                        'frequency_penalty': 0.1,
                    }
                },
                opening_statement='Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. ',
                suggested_questions=None,
                pre_prompt="You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? After the user answers this question, please express your understanding and support.\n",
                model=json.dumps({
                    "provider": "openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat",
                    "completion_params": {
                        "max_tokens": 300,
                        "temperature": 0.8,
                        "top_p": 0.9,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1
                    }
                }),
                user_input_form=None
            )
        }
    ],

    'zh-Hans': [
        {
            'name': '翻译助手',
            'icon': '',
            'icon_background': '',
            'description': '一个多语言翻译器,提供多种语言翻译能力,将用户输入的文本翻译成他们需要的语言。',
            'mode': 'completion',
            'model_config': AppModelConfig(
                provider='openai',
                model_id='gpt-3.5-turbo-instruct',
                configs={
                    'prompt_template': "请将以下文本翻译为{{target_language}}:\n",
                    'prompt_variables': [
                        {
                            "key": "target_language",
                            "name": "目标语言",
                            "description": "翻译的目标语言",
                            "type": "select",
                            "default": "中文",
                            "options": [
                                "中文",
                                "英文",
                                "日语",
                                "法语",
                                "俄语",
                                "德语",
                                "西班牙语",
                                "韩语",
                                "意大利语",
                            ]
                        }
                    ],
                    'completion_params': {
                        'max_token': 1000,
                        'temperature': 0,
                        'top_p': 0,
                        'presence_penalty': 0.1,
                        'frequency_penalty': 0.1,
                    }
                },
                opening_statement='',
                suggested_questions=None,
                pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
                model=json.dumps({
                    "provider": "openai",
                    "name": "gpt-3.5-turbo-instruct",
                    "mode": "completion",
                    "completion_params": {
                        "max_tokens": 1000,
                        "temperature": 0,
                        "top_p": 0,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1
                    }
                }),
                user_input_form=json.dumps([
                    {
                        "select": {
                            "label": "目标语言",
                            "variable": "target_language",
                            "description": "翻译的目标语言",
                            "default": "中文",
                            "required": True,
                            'options': [
                                "中文",
                                "英文",
                                "日语",
                                "法语",
                                "俄语",
                                "德语",
                                "西班牙语",
                                "韩语",
                                "意大利语",
                            ]
                        }
                    },{
                        "paragraph": {
                            "label": "文本内容",
                            "variable": "query",
                            "required": True,
                            "default": ""
                        }
                    }
                ])
            )
        },
        {
            'name': 'AI 前端面试官',
            'icon': '',
            'icon_background': '',
            'description': '一个模拟的前端面试官,通过提问的方式对前端开发的技能水平进行检验。',
            'mode': 'chat',
            'model_config': AppModelConfig(
                provider='openai',
                model_id='gpt-3.5-turbo',
                configs={
                    'introduction': '你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。',
                    'prompt_template': "你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n",
                    'prompt_variables': [],
                    'completion_params': {
                        'max_token': 300,
                        'temperature': 0.8,
                        'top_p': 0.9,
                        'presence_penalty': 0.1,
                        'frequency_penalty': 0.1,
                    }
                },
                opening_statement='你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。',
                suggested_questions=None,
                pre_prompt="你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n",
                model=json.dumps({
                    "provider": "openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat",
                    "completion_params": {
                        "max_tokens": 300,
                        "temperature": 0.8,
                        "top_p": 0.9,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1
                    }
                }),
                user_input_form=None
            )
        }
    ],

}
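A short usage sketch of what this new module exports; the data is inlined and excerpted from the hunk above so the sketch runs standalone, and the assertions restate values visible in it:

```python
# Standalone excerpt of api/constants/languages.py behavior.
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT']
language_timezone_mapping = {'en-US': 'America/New_York', 'ja-JP': 'Asia/Tokyo'}  # excerpt


def supported_language(lang):
    # Validator: echoes a known language code, raises on anything else.
    if lang in languages:
        return lang
    raise ValueError('{lang} is not a valid language.'.format(lang=lang))


assert languages[0] == 'en-US'                       # the fallback default
assert language_timezone_mapping['ja-JP'] == 'Asia/Tokyo'
assert supported_language('pt-BR') == 'pt-BR'        # valid codes pass through
try:
    supported_language('xx-XX')                      # unknown codes raise
except ValueError as e:
    assert str(e) == 'xx-XX is not a valid language.'
```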
@@ -96,258 +96,3 @@ model_templates = {
}


demo_model_templates = {
    'en-US': [
        {
            'name': 'Translation Assistant',
            'icon': '',
            'icon_background': '',
            'description': 'A multilingual translator that provides translation capabilities in multiple languages, translating user input into the language they need.',
            'mode': 'completion',
            'model_config': AppModelConfig(
                provider='openai',
                model_id='gpt-3.5-turbo-instruct',
                configs={
                    'prompt_template': "Please translate the following text into {{target_language}}:\n",
                    'prompt_variables': [
                        {
                            "key": "target_language",
                            "name": "Target Language",
                            "description": "The language you want to translate into.",
                            "type": "select",
                            "default": "Chinese",
                            'options': [
                                'Chinese',
                                'English',
                                'Japanese',
                                'French',
                                'Russian',
                                'German',
                                'Spanish',
                                'Korean',
                                'Italian',
                            ]
                        }
                    ],
                    'completion_params': {
                        'max_token': 1000,
                        'temperature': 0,
                        'top_p': 0,
                        'presence_penalty': 0.1,
                        'frequency_penalty': 0.1,
                    }
                },
                opening_statement='',
                suggested_questions=None,
                pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
                model=json.dumps({
                    "provider": "openai",
                    "name": "gpt-3.5-turbo-instruct",
                    "mode": "completion",
                    "completion_params": {
                        "max_tokens": 1000,
                        "temperature": 0,
                        "top_p": 0,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1
                    }
                }),
                user_input_form=json.dumps([
                    {
                        "select": {
                            "label": "Target Language",
                            "variable": "target_language",
                            "description": "The language you want to translate into.",
                            "default": "Chinese",
                            "required": True,
                            'options': [
                                'Chinese',
                                'English',
                                'Japanese',
                                'French',
                                'Russian',
                                'German',
                                'Spanish',
                                'Korean',
                                'Italian',
                            ]
                        }
                    },{
                        "paragraph": {
                            "label": "Query",
                            "variable": "query",
                            "required": True,
                            "default": ""
                        }
                    }
                ])
            )
        },
        {
            'name': 'AI Front-end Interviewer',
            'icon': '',
            'icon_background': '',
            'description': 'A simulated front-end interviewer that tests the skill level of front-end development through questioning.',
            'mode': 'chat',
            'model_config': AppModelConfig(
                provider='openai',
                model_id='gpt-3.5-turbo',
                configs={
                    'introduction': 'Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. ',
                    'prompt_template': "You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? After the user answers this question, please express your understanding and support.\n",
                    'prompt_variables': [],
                    'completion_params': {
                        'max_token': 300,
                        'temperature': 0.8,
                        'top_p': 0.9,
                        'presence_penalty': 0.1,
                        'frequency_penalty': 0.1,
                    }
                },
                opening_statement='Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. ',
                suggested_questions=None,
                pre_prompt="You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? After the user answers this question, please express your understanding and support.\n",
                model=json.dumps({
                    "provider": "openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat",
                    "completion_params": {
                        "max_tokens": 300,
                        "temperature": 0.8,
                        "top_p": 0.9,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1
                    }
                }),
                user_input_form=None
            )
        }
    ],

    'zh-Hans': [
        {
            'name': '翻译助手',
            'icon': '',
            'icon_background': '',
            'description': '一个多语言翻译器,提供多种语言翻译能力,将用户输入的文本翻译成他们需要的语言。',
            'mode': 'completion',
            'model_config': AppModelConfig(
                provider='openai',
                model_id='gpt-3.5-turbo-instruct',
                configs={
                    'prompt_template': "请将以下文本翻译为{{target_language}}:\n",
                    'prompt_variables': [
                        {
                            "key": "target_language",
                            "name": "目标语言",
                            "description": "翻译的目标语言",
                            "type": "select",
                            "default": "中文",
                            "options": [
                                "中文",
                                "英文",
                                "日语",
                                "法语",
                                "俄语",
                                "德语",
                                "西班牙语",
                                "韩语",
                                "意大利语",
                            ]
                        }
                    ],
                    'completion_params': {
                        'max_token': 1000,
                        'temperature': 0,
                        'top_p': 0,
                        'presence_penalty': 0.1,
                        'frequency_penalty': 0.1,
                    }
                },
                opening_statement='',
                suggested_questions=None,
                pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
                model=json.dumps({
                    "provider": "openai",
                    "name": "gpt-3.5-turbo-instruct",
                    "mode": "completion",
                    "completion_params": {
                        "max_tokens": 1000,
                        "temperature": 0,
                        "top_p": 0,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1
                    }
                }),
                user_input_form=json.dumps([
                    {
                        "select": {
                            "label": "目标语言",
                            "variable": "target_language",
                            "description": "翻译的目标语言",
                            "default": "中文",
                            "required": True,
                            'options': [
                                "中文",
                                "英文",
                                "日语",
                                "法语",
                                "俄语",
                                "德语",
                                "西班牙语",
                                "韩语",
                                "意大利语",
                            ]
                        }
                    },{
                        "paragraph": {
                            "label": "文本内容",
                            "variable": "query",
                            "required": True,
                            "default": ""
                        }
                    }
                ])
            )
        },
        {
            'name': 'AI 前端面试官',
            'icon': '',
            'icon_background': '',
            'description': '一个模拟的前端面试官,通过提问的方式对前端开发的技能水平进行检验。',
            'mode': 'chat',
            'model_config': AppModelConfig(
                provider='openai',
                model_id='gpt-3.5-turbo',
                configs={
                    'introduction': '你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。',
                    'prompt_template': "你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n",
                    'prompt_variables': [],
                    'completion_params': {
                        'max_token': 300,
                        'temperature': 0.8,
                        'top_p': 0.9,
                        'presence_penalty': 0.1,
                        'frequency_penalty': 0.1,
                    }
                },
                opening_statement='你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。',
                suggested_questions=None,
                pre_prompt="你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n",
                model=json.dumps({
                    "provider": "openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat",
                    "completion_params": {
                        "max_tokens": 300,
                        "temperature": 0.8,
                        "top_p": 0.9,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1
                    }
                }),
                user_input_form=None
            )
        }
    ],
}
@@ -16,7 +16,5 @@ from .billing import billing
 from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
 # Import explore controllers
 from .explore import audio, completion, conversation, installed_app, message, parameter, recommended_app, saved_message
-# Import universal chat controllers
-from .universal_chat import audio, chat, conversation, message, parameter
 # Import workspace controllers
 from .workspace import account, members, model_providers, models, tool_providers, workspace
@@ -6,7 +6,7 @@ from controllers.console.wraps import only_edition_cloud
 from extensions.ext_database import db
 from flask import request
 from flask_restful import Resource, reqparse
-from libs.helper import supported_language
+from constants.languages import supported_language
 from models.model import App, InstalledApp, RecommendedApp
 from werkzeug.exceptions import NotFound, Unauthorized
@@ -3,7 +3,8 @@ import json
 import logging
 from datetime import datetime
 
-from constants.model_template import demo_model_templates, model_templates
+from constants.model_template import model_templates
+from constants.languages import demo_model_templates, languages
 from controllers.console import api
 from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
 from controllers.console.setup import setup_required

@@ -16,14 +17,15 @@ from events.app_event import app_was_created, app_was_deleted
 from extensions.ext_database import db
 from fields.app_fields import (app_detail_fields, app_detail_fields_with_site, app_pagination_fields,
                                template_list_fields)
 from flask import current_app
 from flask_login import current_user
 from flask_restful import Resource, abort, inputs, marshal_with, reqparse
 from libs.login import login_required
 from models.model import App, AppModelConfig, Site
+from models.tools import ApiToolProvider
 from services.app_model_config_service import AppModelConfigService
 from werkzeug.exceptions import Forbidden
 
 
 def _get_app(app_id, tenant_id):
     app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
     if not app:
@@ -42,14 +44,31 @@ class AppListApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
|
||||
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('mode', type=str, choices=['chat', 'completion', 'all'], default='all', location='args', required=False)
|
||||
parser.add_argument('name', type=str, location='args', required=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
filters = [
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.is_universal == False
|
||||
]
|
||||
|
||||
if args['mode'] == 'completion':
|
||||
filters.append(App.mode == 'completion')
|
||||
elif args['mode'] == 'chat':
|
||||
filters.append(App.mode == 'chat')
|
||||
else:
|
||||
pass
|
||||
|
||||
if 'name' in args and args['name']:
|
||||
filters.append(App.name.ilike(f'%{args["name"]}%'))
|
||||
|
||||
app_models = db.paginate(
|
||||
db.select(App).where(App.tenant_id == current_user.current_tenant_id,
|
||||
App.is_universal == False).order_by(App.created_at.desc()),
|
||||
db.select(App).where(*filters).order_by(App.created_at.desc()),
|
||||
page=args['page'],
|
||||
per_page=args['limit'],
|
||||
error_out=False)
|
||||
error_out=False
|
||||
)
|
||||
|
||||
return app_models
|
||||
|
||||
@@ -62,7 +81,7 @@ class AppListApi(Resource):
|
||||
"""Create app"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('mode', type=str, choices=['completion', 'chat'], location='json')
|
||||
parser.add_argument('mode', type=str, choices=['completion', 'chat', 'assistant'], location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
parser.add_argument('model_config', type=dict, location='json')
|
||||
@@ -178,7 +197,7 @@ class AppListApi(Resource):
|
||||
app_was_created.send(app)
|
||||
|
||||
return app, 201
|
||||
|
||||
|
||||
|
||||
class AppTemplateApi(Resource):
|
||||
|
||||
@@ -193,7 +212,7 @@ class AppTemplateApi(Resource):
|
||||
|
||||
templates = demo_model_templates.get(interface_language)
|
||||
if not templates:
|
||||
templates = demo_model_templates.get('en-US')
|
||||
templates = demo_model_templates.get(languages[0])
|
||||
|
||||
return {'data': templates}
|
||||
|
||||
|
||||
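The AppListApi hunk above replaces an inline two-condition where-clause with a composable filters list, so the optional mode and name filters can be appended before pagination. As a rough illustration of that pattern (not Dify's actual code; the model and field names here are simplified stand-ins), combining a SQLAlchemy filter list with Flask-SQLAlchemy 3.x's db.paginate might look like this:

```python
# A minimal, self-contained sketch of the filter-list pattern used above.
# Assumes Flask-SQLAlchemy 3.x (db.select / db.paginate); names are illustrative.
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///:memory:'
db = SQLAlchemy(app)

class DemoApp(db.Model):  # stand-in for models.model.App
    id = db.Column(db.Integer, primary_key=True)
    tenant_id = db.Column(db.String)
    mode = db.Column(db.String)
    name = db.Column(db.String)

def list_apps(tenant_id, mode='all', name=None, page=1, limit=20):
    # start from the filters that always apply
    filters = [DemoApp.tenant_id == tenant_id]
    # optional filters are appended only when requested
    if mode in ('chat', 'completion'):
        filters.append(DemoApp.mode == mode)
    if name:
        filters.append(DemoApp.name.ilike(f'%{name}%'))
    # *filters unpacks the list into .where(), AND-ing all conditions
    return db.paginate(
        db.select(DemoApp).where(*filters).order_by(DemoApp.id.desc()),
        page=page, per_page=limit, error_out=False,
    )
```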
@@ -32,9 +32,10 @@ class ChatMessageAudioApi(Resource):
file = request.files['file']

try:
response = AudioService.transcript(
response = AudioService.transcript_asr(
tenant_id=app_model.tenant_id,
file=file,
promot=app_model.app_model_config.pre_prompt
)

return response
@@ -62,6 +63,48 @@ class ChatMessageAudioApi(Resource):
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()


api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')

class ChatMessageTextApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_id):
app_id = str(app_id)
app_model = _get_app(app_id, None)
try:
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=request.form['text'],
streaming=False
)

return {'data': response.data.decode('latin1')}
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()


api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')

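For orientation, the new text-to-audio route above accepts the text as form data and returns the audio bytes latin1-decoded inside a JSON envelope. A hypothetical client call (base URL, app id, and auth header are placeholders, not taken from the diff) could look like:

```python
# Hypothetical client for the console text-to-audio endpoint sketched above.
# The base URL, app id, and auth scheme are assumptions for illustration only.
import requests

resp = requests.post(
    'http://localhost:5001/console/api/apps/<app_id>/text-to-audio',
    data={'text': 'Hello, Dify!'},          # sent as form data, matching request.form['text']
    headers={'Authorization': 'Bearer <console-token>'},
)
audio_bytes = resp.json()['data'].encode('latin1')  # reverse the latin1 decode done server-side
with open('tts-output.mp3', 'wb') as f:
    f.write(audio_bytes)
```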
@@ -163,29 +163,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
try:
for chunk in response:
yield chunk
except services.errors.conversation.ConversationNotExistsError:
yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
except services.errors.conversation.ConversationCompletedError:
yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
logging.exception("internal server error.")
yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
for chunk in response:
yield chunk

return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

@@ -241,27 +241,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
try:
for chunk in response:
yield chunk
except MessageNotExistsError:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(
api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
logging.exception("internal server error.")
yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
for chunk in response:
yield chunk

return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

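Both hunks above collapse a ladder of per-exception SSE handlers into a plain passthrough, leaving error handling to whatever produces the chunks upstream. What remains is the standard Flask server-sent-events passthrough pattern, roughly (a generic sketch, not Dify's exact code):

```python
# A generic sketch of the simplified SSE passthrough pattern shown above.
import json
from typing import Generator, Union
from flask import Response, stream_with_context

def compact_response_sketch(response: Union[dict, Generator]) -> Response:
    if isinstance(response, dict):
        # blocking mode: a single JSON document
        return Response(response=json.dumps(response), status=200, mimetype='application/json')

    def generate() -> Generator:
        # streaming mode: forward chunks verbatim; exceptions propagate to
        # the generator's producer instead of being caught per-type here
        for chunk in response:
            yield chunk

    return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')
```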
@@ -7,7 +7,7 @@ from extensions.ext_database import db
from fields.app_fields import app_site_fields
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from libs.helper import supported_language
from constants.languages import supported_language
from libs.login import login_required
from models.model import Site
from werkzeug.exceptions import Forbidden, NotFound

@@ -6,7 +6,8 @@ from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from flask_restful import Resource, reqparse
from libs.helper import email, str_len, supported_language, timezone
from libs.helper import email, str_len, timezone
from constants.languages import supported_language
from libs.password import hash_password, valid_password
from models.account import AccountStatus, Tenant
from services.account_service import RegisterService

@@ -3,6 +3,7 @@ from datetime import datetime
from typing import Optional

import requests
from constants.languages import languages
from extensions.ext_database import db
from flask import current_app, redirect, request
from flask_restful import Resource
@@ -106,11 +107,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
)

# Set interface language
preferred_lang = request.accept_languages.best_match(['zh', 'en'])
if preferred_lang == 'zh':
interface_language = 'zh-Hans'
preferred_lang = request.accept_languages.best_match(languages)
if preferred_lang and preferred_lang in languages:
interface_language = preferred_lang
else:
interface_language = 'en-US'
interface_language = languages[0]
account.interface_language = interface_language
db.session.commit()


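The OAuth hunk above swaps a hard-coded zh/en check for negotiation against the full supported-language list, falling back to the first entry. The werkzeug behaviour it relies on can be sketched as follows (the `languages` list here is illustrative, standing in for constants.languages.languages):

```python
# Sketch of Accept-Language negotiation with werkzeug, as used above.
# `languages` is an assumed stand-in; first entry is treated as the default.
from werkzeug.datastructures import LanguageAccept

languages = ['en-US', 'zh-Hans', 'ja-JP']

# In a request context this would be request.accept_languages.
accept = LanguageAccept([('zh-Hans', 1), ('en-US', 0.8)])

preferred = accept.best_match(languages)
interface_language = preferred if preferred and preferred in languages else languages[0]
print(interface_language)  # -> 'zh-Hans'
```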
@@ -19,7 +19,7 @@ from flask import current_app, request
from flask_login import current_user
from flask_restful import Resource, marshal, marshal_with, reqparse
from libs.login import login_required
from models.dataset import Document, DocumentSegment
from models.dataset import Dataset, Document, DocumentSegment
from models.model import ApiToken, UploadFile
from services.dataset_service import DatasetService, DocumentService
from werkzeug.exceptions import Forbidden, NotFound
@@ -97,7 +97,8 @@ class DatasetListApi(Resource):
help='type is required. Name must be between 1 to 40 characters.',
type=_validate_name)
parser.add_argument('indexing_technique', type=str, location='json',
choices=('high_quality', 'economy'),
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help='Invalid indexing technique.')
args = parser.parse_args()

@@ -177,8 +178,9 @@ class DatasetApi(Resource):
location='json', store_missing=False,
type=_validate_description_length)
parser.add_argument('indexing_technique', type=str, location='json',
choices=('high_quality', 'economy'),
help='Invalid indexing technique.')
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
'only_me', 'all_team_members'), help='Invalid permission.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
@@ -256,7 +258,9 @@ class DatasetIndexingEstimateApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
parser.add_argument('indexing_technique', type=str, required=True,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,

@@ -9,7 +9,7 @@ from flask import current_app, request
from flask_login import current_user
from flask_restful import Resource, marshal_with
from libs.login import login_required
from services.file_service import FileService
from services.file_service import FileService, ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS

PREVIEW_WORDS_LIMIT = 3000

@@ -71,11 +71,7 @@ class FileSupportTypeApi(Resource):
@account_initialization_required
def get(self):
etl_type = current_app.config['ETL_TYPE']
if etl_type == 'Unstructured':
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml']
else:
allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
return {'allowed_extensions': allowed_extensions}



@@ -29,7 +29,7 @@ class ChatAudioApi(InstalledAppResource):
file = request.files['file']

try:
response = AudioService.transcript(
response = AudioService.transcript_asr(
tenant_id=app_model.tenant_id,
file=file,
)
@@ -59,6 +59,48 @@ class ChatAudioApi(InstalledAppResource):
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()


api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')

class ChatTextApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
app_model_config: AppModelConfig = app_model.app_model_config

if not app_model_config.text_to_speech_dict['enabled']:
raise AppUnavailableError()

try:
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=request.form['text'],
streaming=False
)
return {'data': response.data.decode('latin1')}
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()


api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')

@@ -158,29 +158,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
try:
for chunk in response:
yield chunk
except services.errors.conversation.ConversationNotExistsError:
yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
except services.errors.conversation.ConversationCompletedError:
yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
logging.exception("internal server error.")
yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
for chunk in response:
yield chunk

return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

@@ -33,8 +33,9 @@ class InstalledAppsListApi(Resource):
'app_owner_tenant_id': installed_app.app_owner_tenant_id,
'is_pinned': installed_app.is_pinned,
'last_used_at': installed_app.last_used_at,
"editable": current_user.role in ["owner", "admin"],
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id
'editable': current_user.role in ["owner", "admin"],
'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id,
'is_agent': installed_app.is_agent
}
for installed_app in installed_apps
]

@@ -17,9 +17,9 @@ from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import marshal_with, reqparse
from flask_restful import marshal_with, reqparse, fields
from flask_restful.inputs import int_range
from libs.helper import uuid_value
from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@@ -29,7 +29,6 @@ from werkzeug.exceptions import InternalServerError, NotFound


class MessageListApi(InstalledAppResource):

@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, installed_app):
app_model = installed_app.app
@@ -51,7 +50,6 @@ class MessageListApi(InstalledAppResource):
except services.errors.message.FirstMessageNotExistsError:
raise NotFound("First Message Not Exists.")


class MessageFeedbackApi(InstalledAppResource):
def post(self, installed_app, message_id):
app_model = installed_app.app
@@ -117,26 +115,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
try:
for chunk in response:
yield chunk
except MessageNotExistsError:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
logging.exception("internal server error.")
yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
for chunk in response:
yield chunk

return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

@@ -1,10 +1,14 @@
# -*- coding:utf-8 -*-
import json

from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource
from flask import current_app
from flask_restful import fields, marshal_with
from models.model import InstalledApp
from models.model import InstalledApp, AppModelConfig
from models.tools import ApiToolProvider

from extensions.ext_database import db

class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""
@@ -27,6 +31,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'text_to_speech': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
@@ -47,6 +52,7 @@ class AppParameterApi(InstalledAppResource):
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'text_to_speech': app_model_config.text_to_speech_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
'more_like_this': app_model_config.more_like_this_dict,
@@ -58,5 +64,42 @@ class AppParameterApi(InstalledAppResource):
}
}

class ExploreAppMetaApi(InstalledAppResource):
def get(self, installed_app: InstalledApp):
"""Get app meta"""
app_model_config: AppModelConfig = installed_app.app.app_model_config

agent_config = app_model_config.agent_mode_dict or {}
meta = {
'tool_icons': {}
}

# get all tools
tools = agent_config.get('tools', [])
url_prefix = (current_app.config.get("CONSOLE_API_URL")
+ f"/console/api/workspaces/current/tool-provider/builtin/")
for tool in tools:
keys = list(tool.keys())
if len(keys) >= 4:
# current tool standard
provider_type = tool.get('provider_type')
provider_id = tool.get('provider_id')
tool_name = tool.get('tool_name')
if provider_type == 'builtin':
meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon'
elif provider_type == 'api':
try:
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.id == provider_id
).first()
meta['tool_icons'][tool_name] = json.loads(provider.icon)
except:
meta['tool_icons'][tool_name] = {
"background": "#252525",
"content": "\ud83d\ude01"
}

return meta

api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters', endpoint='installed_app_parameters')
api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta')

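ExploreAppMetaApi builds a tool-name to icon map: builtin tools get a console URL, API tools get the icon JSON stored on the provider row, and any failure falls back to a default emoji icon. Note that the query ending in `).first()` above has been completed so that `provider.icon` reads the row rather than the query object. Factored out of Flask, the lookup rule is roughly (the helper callable and its behaviour are my assumptions; the default icon is copied from the code above):

```python
# Standalone sketch of the icon-resolution rule used by ExploreAppMetaApi.
import json

DEFAULT_ICON = {"background": "#252525", "content": "\U0001F601"}

def resolve_tool_icon(tool: dict, url_prefix: str, load_api_provider_icon):
    """load_api_provider_icon is an assumed helper returning the provider's
    icon JSON string for a given provider id, or raising on failure."""
    provider_type = tool.get('provider_type')
    provider_id = tool.get('provider_id')
    if provider_type == 'builtin':
        return url_prefix + provider_id + '/icon'   # icon is served by the console API
    if provider_type == 'api':
        try:
            return json.loads(load_api_provider_icon(provider_id))
        except Exception:
            return DEFAULT_ICON                      # broken provider rows degrade gracefully
    return DEFAULT_ICON
```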
@@ -9,6 +9,7 @@ from libs.login import login_required
from models.model import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
from sqlalchemy import and_
from constants.languages import languages

app_fields = {
'id': fields.String,
@@ -29,7 +30,8 @@ recommended_app_fields = {
'is_listed': fields.Boolean,
'install_count': fields.Integer,
'installed': fields.Boolean,
'editable': fields.Boolean
'editable': fields.Boolean,
'is_agent': fields.Boolean
}

recommended_app_list_fields = {
@@ -43,7 +45,7 @@ class RecommendedAppListApi(Resource):
@account_initialization_required
@marshal_with(recommended_app_list_fields)
def get(self):
language_prefix = current_user.interface_language if current_user.interface_language else 'en-US'
language_prefix = current_user.interface_language if current_user.interface_language else languages[0]

recommended_apps = db.session.query(RecommendedApp).filter(
RecommendedApp.is_listed == True,
@@ -82,6 +84,7 @@ class RecommendedAppListApi(Resource):
'install_count': recommended_app.install_count,
'installed': installed,
'editable': current_user.role in ['owner', 'admin'],
"is_agent": app.is_agent
}
recommended_apps_result.append(recommended_app_result)


@@ -1,64 +0,0 @@
# -*- coding:utf-8 -*-
import logging

import services
from controllers.console import api
from controllers.console.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError,
NoAudioUploadedError, ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError, ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError, UnsupportedAudioTypeError)
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from flask import request
from models.model import AppModelConfig
from services.audio_service import AudioService
from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError)
from werkzeug.exceptions import InternalServerError


class UniversalChatAudioApi(UniversalChatResource):
def post(self, universal_app):
app_model = universal_app
app_model_config: AppModelConfig = app_model.app_model_config

if not app_model_config.speech_to_text_dict['enabled']:
raise AppUnavailableError()

file = request.files['file']

try:
response = AudioService.transcript(
tenant_id=app_model.tenant_id,
file=file,
)

return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()


api.add_resource(UniversalChatAudioApi, '/universal-chat/audio-to-text')
@@ -1,141 +0,0 @@
import json
import logging
from typing import Generator, Union

import services
from controllers.console import api
from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError,
ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError,
ProviderQuotaExceededError)
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import reqparse
from libs.helper import uuid_value
from services.completion_service import CompletionService
from werkzeug.exceptions import InternalServerError, NotFound


class UniversalChatApi(UniversalChatResource):
def post(self, universal_app):
app_model = universal_app

parser = reqparse.RequestParser()
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('provider', type=str, required=True, location='json')
parser.add_argument('model', type=str, required=True, location='json')
parser.add_argument('tools', type=list, required=True, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='universal_app', location='json')
args = parser.parse_args()

app_model_config = app_model.app_model_config

# update app model config
args['model_config'] = app_model_config.to_dict()
args['model_config']['model']['name'] = args['model']
args['model_config']['model']['provider'] = args['provider']
args['model_config']['agent_mode']['tools'] = args['tools']

if not args['model_config']['agent_mode']['tools']:
args['model_config']['agent_mode']['tools'] = [
{
"current_datetime": {
"enabled": True
}
}
]
else:
args['model_config']['agent_mode']['tools'].append({
"current_datetime": {
"enabled": True
}
})

args['inputs'] = {}

del args['model']
del args['tools']

args['auto_generate_name'] = False

try:
response = CompletionService.completion(
app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.EXPLORE,
streaming=True,
is_model_config_override=True,
)

return compact_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()


class UniversalChatStopApi(UniversalChatResource):
def post(self, universal_app, task_id):
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

return {'result': 'success'}, 200


def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
try:
for chunk in response:
yield chunk
except services.errors.conversation.ConversationNotExistsError:
yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
except services.errors.conversation.ConversationCompletedError:
yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
logging.exception("internal server error.")
yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"

return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')


api.add_resource(UniversalChatApi, '/universal-chat/messages')
api.add_resource(UniversalChatStopApi, '/universal-chat/messages/<string:task_id>/stop')
@@ -1,110 +0,0 @@
# -*- coding:utf-8 -*-
from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource
from fields.conversation_fields import (conversation_with_model_config_fields,
conversation_with_model_config_infinite_scroll_pagination_fields)
from flask_login import current_user
from flask_restful import fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.web_conversation_service import WebConversationService
from werkzeug.exceptions import NotFound


class UniversalChatConversationListApi(UniversalChatResource):

@marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields)
def get(self, universal_app):
app_model = universal_app

parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
args = parser.parse_args()

pinned = None
if 'pinned' in args and args['pinned'] is not None:
pinned = True if args['pinned'] == 'true' else False

try:
return WebConversationService.pagination_by_last_id(
app_model=app_model,
user=current_user,
last_id=args['last_id'],
limit=args['limit'],
pinned=pinned
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")


class UniversalChatConversationApi(UniversalChatResource):
def delete(self, universal_app, c_id):
app_model = universal_app
conversation_id = str(c_id)

try:
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

WebConversationService.unpin(app_model, conversation_id, current_user)

return {"result": "success"}, 204


class UniversalChatConversationRenameApi(UniversalChatResource):

@marshal_with(conversation_with_model_config_fields)
def post(self, universal_app, c_id):
app_model = universal_app
conversation_id = str(c_id)

parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()

try:
return ConversationService.rename(
app_model,
conversation_id,
current_user,
args['name'],
args['auto_generate']
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")


class UniversalChatConversationPinApi(UniversalChatResource):

def patch(self, universal_app, c_id):
app_model = universal_app
conversation_id = str(c_id)

try:
WebConversationService.pin(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

return {"result": "success"}


class UniversalChatConversationUnPinApi(UniversalChatResource):
def patch(self, universal_app, c_id):
app_model = universal_app
conversation_id = str(c_id)
WebConversationService.unpin(app_model, conversation_id, current_user)

return {"result": "success"}


api.add_resource(UniversalChatConversationRenameApi, '/universal-chat/conversations/<uuid:c_id>/name')
api.add_resource(UniversalChatConversationListApi, '/universal-chat/conversations')
api.add_resource(UniversalChatConversationApi, '/universal-chat/conversations/<uuid:c_id>')
api.add_resource(UniversalChatConversationPinApi, '/universal-chat/conversations/<uuid:c_id>/pin')
api.add_resource(UniversalChatConversationUnPinApi, '/universal-chat/conversations/<uuid:c_id>/unpin')
@@ -1,145 +0,0 @@
# -*- coding:utf-8 -*-
import logging

import services
from controllers.console import api
from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError, ProviderQuotaExceededError)
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from flask_login import current_user
from flask_restful import fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from libs.helper import TimestampField, uuid_value
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
from werkzeug.exceptions import InternalServerError, NotFound


class UniversalChatMessageListApi(UniversalChatResource):
feedback_fields = {
'rating': fields.String
}

agent_thought_fields = {
'id': fields.String,
'chain_id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'thought': fields.String,
'tool': fields.String,
'tool_input': fields.String,
'created_at': TimestampField
}

retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
'created_at': TimestampField
}

message_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField,
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
}

message_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_fields))
}

@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, universal_app):
app_model = universal_app

parser = reqparse.RequestParser()
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args()

try:
return MessageService.pagination_by_first_id(app_model, current_user,
args['conversation_id'], args['first_id'], args['limit'])
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.message.FirstMessageNotExistsError:
raise NotFound("First Message Not Exists.")


class UniversalChatMessageFeedbackApi(UniversalChatResource):
def post(self, universal_app, message_id):
app_model = universal_app
message_id = str(message_id)

parser = reqparse.RequestParser()
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
args = parser.parse_args()

try:
MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")

return {'result': 'success'}


class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
def get(self, universal_app, message_id):
app_model = universal_app
message_id = str(message_id)

try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=current_user,
message_id=message_id
)
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except Exception:
logging.exception("internal server error.")
raise InternalServerError()

return {'data': questions}


api.add_resource(UniversalChatMessageListApi, '/universal-chat/messages')
api.add_resource(UniversalChatMessageFeedbackApi, '/universal-chat/messages/<uuid:message_id>/feedbacks')
api.add_resource(UniversalChatMessageSuggestedQuestionApi, '/universal-chat/messages/<uuid:message_id>/suggested-questions')
@@ -1,38 +0,0 @@
# -*- coding:utf-8 -*-
import json

from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource
from flask_restful import fields, marshal_with
from models.model import App


class UniversalChatParameterApi(UniversalChatResource):
"""Resource for app variables."""
parameters_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw
}

@marshal_with(parameters_fields)
def get(self, universal_app: App):
"""Retrieve app parameters."""
app_model = universal_app
app_model_config = app_model.app_model_config
app_model_config.retriever_resource = json.dumps({'enabled': True})

return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'annotation_reply': app_model_config.annotation_reply_dict,
}


api.add_resource(UniversalChatParameterApi, '/universal-chat/parameters')
@@ -1,86 +0,0 @@
import json
from functools import wraps

from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from flask_login import current_user
from flask_restful import Resource
from libs.login import login_required
from models.model import App, AppModelConfig


def universal_chat_app_required(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
# get universal chat app
universal_app = db.session.query(App).filter(
App.tenant_id == current_user.current_tenant_id,
App.is_universal == True
).first()

if universal_app is None:
# create universal app if not exists
universal_app = App(
tenant_id=current_user.current_tenant_id,
name='Universal Chat',
mode='chat',
is_universal=True,
icon='',
icon_background='',
api_rpm=0,
api_rph=0,
enable_site=False,
enable_api=False,
status='normal'
)

db.session.add(universal_app)
db.session.flush()

app_model_config = AppModelConfig(
provider="",
model_id="",
configs={},
opening_statement='',
suggested_questions=json.dumps([]),
suggested_questions_after_answer=json.dumps({'enabled': True}),
speech_to_text=json.dumps({'enabled': True}),
retriever_resource=json.dumps({'enabled': True}),
more_like_this=None,
sensitive_word_avoidance=None,
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-16k",
"completion_params": {
"max_tokens": 800,
"temperature": 0.8,
"top_p": 1,
"presence_penalty": 0,
"frequency_penalty": 0
}
}),
user_input_form=json.dumps([]),
pre_prompt='',
agent_mode=json.dumps({"enabled": True, "strategy": "function_call", "tools": []}),
)

app_model_config.app_id = universal_app.id
db.session.add(app_model_config)
db.session.flush()

universal_app.app_model_config_id = app_model_config.id
db.session.commit()

return view(universal_app, *args, **kwargs)
return decorated

if view:
return decorator(view)
return decorator


class UniversalChatResource(Resource):
# must be reversed if there are multiple decorators
method_decorators = [universal_chat_app_required, account_initialization_required, login_required, setup_required]
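The deleted wraps module ends with a detail worth keeping in mind: flask_restful applies `method_decorators` in list order, so the last entry ends up outermost and runs first on a request, which is why the list above reads bottom-up (setup check first, app provisioning last). A toy illustration of that folding, under the assumption that flask_restful wraps the handler with each decorator in sequence:

```python
# Toy illustration of flask_restful's method_decorators ordering.
# Decorators listed later wrap earlier ones, so they run first on a request.
from functools import wraps

def tag(name):
    def deco(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            print(f'enter {name}')
            return fn(*args, **kwargs)
        return inner
    return deco

method_decorators = [tag('innermost'), tag('middle'), tag('outermost')]

def handler():
    print('handler body')

for deco in method_decorators:  # flask_restful folds the list like this
    handler = deco(handler)

handler()  # prints: enter outermost, enter middle, enter innermost, handler body
```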
@@ -11,7 +11,8 @@ from extensions.ext_database import db
from flask import current_app, request
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with, reqparse
from libs.helper import TimestampField, supported_language, timezone
from libs.helper import TimestampField, timezone
from constants.languages import supported_language
from libs.login import login_required
from models.account import AccountIntegrate, InvitationCode
from services.account_service import AccountService

@@ -1,15 +1,16 @@
# -*- coding:utf-8 -*-
from flask import current_app
from flask_login import current_user
from flask_restful import Resource, abort, fields, marshal_with, reqparse

import services
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from flask import current_app
from flask_login import current_user
from flask_restful import Resource, abort, fields, marshal, marshal_with, reqparse
from libs.helper import TimestampField
from libs.login import login_required
from models.account import Account, TenantAccountJoin
from models.account import Account
from services.account_service import RegisterService, TenantService

account_fields = {
@@ -64,18 +65,12 @@ class MemberInviteEmailApi(Resource):
for invitee_email in invitee_emails:
try:
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
inviter=inviter)
account = db.session.query(Account, TenantAccountJoin.role).join(
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
).filter(Account.email == invitee_email).first()
account, role = account
inviter=inviter)
invitation_results.append({
'status': 'success',
'email': invitee_email,
'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
})
account = marshal(account, account_fields)
account['role'] = role
except Exception as e:
invitation_results.append({
'status': 'failed',

@@ -1,136 +1,293 @@
import json

from libs.login import login_required
from flask_login import current_user
from flask_restful import Resource, reqparse
from flask import send_file
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.provider.tool_provider_service import ToolProviderService
from extensions.ext_database import db
from flask_login import current_user
from flask_restful import Resource, abort, reqparse
from libs.login import login_required
from models.tool import ToolProvider, ToolProviderName
from werkzeug.exceptions import Forbidden

from services.tools_manage_service import ToolManageService

import io

class ToolProviderListApi(Resource):

@setup_required
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id

tool_credential_dict = {}
for tool_name in ToolProviderName:
tool_credential_dict[tool_name.value] = {
'tool_name': tool_name.value,
'is_enabled': False,
'credentials': None
}
return ToolManageService.list_tool_providers(user_id, tenant_id)

tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all()
class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
user_id = current_user.id
tenant_id = current_user.current_tenant_id

for p in tool_providers:
if p.is_enabled:
tool_credential_dict[p.tool_name] = {
'tool_name': p.tool_name,
'is_enabled': p.is_enabled,
'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True)
}

return list(tool_credential_dict.values())


class ToolProviderCredentialsApi(Resource):
return ToolManageService.list_builtin_tool_provider_tools(
user_id,
tenant_id,
provider,
)

class ToolBuiltinProviderDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
if provider not in [p.value for p in ToolProviderName]:
abort(404)

# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden(f'User {current_user.id} is not authorized to update provider token, '
f'current_role is {current_user.current_tenant.current_role}')

parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()

raise Forbidden()

user_id = current_user.id
tenant_id = current_user.current_tenant_id

tool_provider_service = ToolProviderService(tenant_id, provider)

try:
tool_provider_service.credentials_validate(args['credentials'])
except ToolValidateFailedError as ex:
raise ValueError(str(ex))

encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials']))

tenant = current_user.current_tenant

tool_provider_model = db.session.query(ToolProvider).filter(
ToolProvider.tenant_id == tenant.id,
ToolProvider.tool_name == provider,
).first()

# Only allow updating token for CUSTOM provider type
if tool_provider_model:
tool_provider_model.encrypted_credentials = encrypted_credentials
tool_provider_model.is_enabled = True
else:
tool_provider_model = ToolProvider(
tenant_id=tenant.id,
tool_name=provider,
encrypted_credentials=encrypted_credentials,
is_enabled=True
)
db.session.add(tool_provider_model)

db.session.commit()

return {'result': 'success'}, 201


class ToolProviderCredentialsValidateApi(Resource):

return ToolManageService.delete_builtin_tool_provider(
user_id,
tenant_id,
provider,
)

class ToolBuiltinProviderUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
if provider not in [p.value for p in ToolProviderName]:
abort(404)
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

user_id = current_user.id
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')

args = parser.parse_args()

result = True
error = None
return ToolManageService.update_builtin_tool_provider(
user_id,
tenant_id,
provider,
args['credentials'],
)

class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider):
icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider)
return send_file(io.BytesIO(icon_bytes), mimetype=minetype)


class ToolApiProviderAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()

user_id = current_user.id
tenant_id = current_user.current_tenant_id

tool_provider_service = ToolProviderService(tenant_id, provider)
parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json')

try:
tool_provider_service.credentials_validate(args['credentials'])
except ToolValidateFailedError as ex:
result = False
error = str(ex)
args = parser.parse_args()

response = {'result': 'success' if result else 'error'}
return ToolManageService.create_api_tool_provider(
user_id,
tenant_id,
args['provider'],
args['icon'],
args['credentials'],
args['schema_type'],
|
||||
args['schema'],
|
||||
args.get('privacy_policy', ''),
|
||||
)
|
||||
|
||||
if not result:
|
||||
response['error'] = error
|
||||
class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
return response
|
||||
parser.add_argument('url', type=str, required=True, nullable=False, location='args')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.get_api_tool_provider_remote_schema(
|
||||
current_user.id,
|
||||
current_user.current_tenant_id,
|
||||
args['url'],
|
||||
)
|
||||
|
||||
class ToolApiProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.list_api_tool_provider_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
)
|
||||
|
||||
class ToolApiProviderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('icon', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('privacy_policy', type=str, required=True, nullable=False, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.update_api_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
args['original_provider'],
|
||||
args['icon'],
|
||||
args['credentials'],
|
||||
args['schema_type'],
|
||||
args['schema'],
|
||||
args['privacy_policy'],
|
||||
)
|
||||
|
||||
class ToolApiProviderDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.delete_api_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
)
|
||||
|
||||
class ToolApiProviderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.get_api_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
)
|
||||
|
||||
class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
return ToolManageService.list_builtin_provider_credentials_schema(provider)
|
||||
|
||||
class ToolApiProviderSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.parser_api_schema(
|
||||
schema=args['schema'],
|
||||
)
|
||||
|
||||
class ToolApiProviderPreviousTestApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.test_api_tool_preview(
|
||||
current_user.current_tenant_id,
|
||||
args['tool_name'],
|
||||
args['credentials'],
|
||||
args['parameters'],
|
||||
args['schema_type'],
|
||||
args['schema'],
|
||||
)
|
||||
|
||||
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
|
||||
api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers/<provider>/credentials')
|
||||
api.add_resource(ToolProviderCredentialsValidateApi,
|
||||
'/workspaces/current/tool-providers/<provider>/credentials-validate')
|
||||
api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin/<provider>/tools')
|
||||
api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin/<provider>/delete')
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
|
||||
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
|
||||
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
|
||||
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
|
||||
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
|
||||
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
|
||||
api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update')
|
||||
api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete')
|
||||
api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get')
|
||||
api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema')
|
||||
api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre')
|
||||
|
||||
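For orientation, here is a minimal sketch of how a client might exercise the new builtin-provider routes registered above. The base URL, session cookie, and the 'google' provider name are placeholder assumptions for illustration, not part of this change:

    import requests

    BASE = 'http://localhost:5001/console/api'            # assumed local console API deployment
    cookies = {'session': '<console session cookie>'}     # placeholder credential

    # list all tool providers visible to the current workspace
    providers = requests.get(f'{BASE}/workspaces/current/tool-providers', cookies=cookies).json()
    print(providers)

    # fetch the icon of a builtin provider; per the new ToolBuiltinProviderIconApi,
    # this route only requires setup, not a logged-in session
    icon = requests.get(f'{BASE}/workspaces/current/tool-provider/builtin/google/icon')
    print(icon.headers.get('Content-Type'), len(icon.content))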
@@ -7,3 +7,4 @@ api = ExternalApi(bp)


from . import image_preview
from . import tool_files

47
api/controllers/files/tool_files.py
Normal file
@@ -0,0 +1,47 @@
from controllers.files import api
from flask import Response
from flask_restful import Resource, reqparse
from libs.exception import BaseHTTPException
from werkzeug.exceptions import NotFound, Forbidden

from core.tools.tool_file_manager import ToolFileManager

class ToolFilePreviewApi(Resource):
    def get(self, file_id, extension):
        file_id = str(file_id)

        parser = reqparse.RequestParser()

        parser.add_argument('timestamp', type=str, required=True, location='args')
        parser.add_argument('nonce', type=str, required=True, location='args')
        parser.add_argument('sign', type=str, required=True, location='args')

        args = parser.parse_args()

        if not ToolFileManager.verify_file(file_id=file_id,
                                           timestamp=args['timestamp'],
                                           nonce=args['nonce'],
                                           sign=args['sign'],
                                           ):
            raise Forbidden('Invalid request.')

        try:
            result = ToolFileManager.get_file_generator_by_message_file_id(
                file_id,
            )

            if not result:
                raise NotFound('file is not found')

            generator, mimetype = result
        except Exception:
            raise UnsupportedFileTypeError()

        return Response(generator, mimetype=mimetype)

api.add_resource(ToolFilePreviewApi, '/files/tools/<uuid:file_id>.<string:extension>')

class UnsupportedFileTypeError(BaseHTTPException):
    error_code = 'unsupported_file_type'
    description = "File type not allowed."
    code = 415
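The preview route authenticates requests with timestamp/nonce/sign query parameters, but the body of ToolFileManager.verify_file is not part of this diff. The following is only a plausible HMAC-based sketch of such a scheme, under the assumptions of a shared server secret, a fixed message layout, and a five-minute TTL; none of these details are confirmed by the source:

    import base64
    import hashlib
    import hmac
    import os
    import time

    SECRET_KEY = b'<server secret>'   # assumption: the server signs links with a shared secret
    SIGN_TTL = 300                    # assumption: signed links expire after five minutes

    def sign_tool_file(file_id: str) -> dict:
        # Produce the query parameters the preview route expects (hypothetical layout).
        timestamp = str(int(time.time()))
        nonce = os.urandom(8).hex()
        msg = f'file={file_id}&timestamp={timestamp}&nonce={nonce}'.encode()
        sign = base64.urlsafe_b64encode(hmac.new(SECRET_KEY, msg, hashlib.sha256).digest()).decode()
        return {'timestamp': timestamp, 'nonce': nonce, 'sign': sign}

    def verify_tool_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
        # Recompute the signature and reject stale or tampered links.
        msg = f'file={file_id}&timestamp={timestamp}&nonce={nonce}'.encode()
        expected = base64.urlsafe_b64encode(hmac.new(SECRET_KEY, msg, hashlib.sha256).digest()).decode()
        return hmac.compare_digest(expected, sign) and time.time() - int(timestamp) < SIGN_TTL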
@@ -6,5 +6,6 @@ bp = Blueprint('service_api', __name__, url_prefix='/v1')
api = ExternalApi(bp)


from . import index
from .app import app, audio, completion, conversation, file, message
from .dataset import dataset, document, segment

@@ -3,7 +3,12 @@ from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
from flask import current_app
from flask_restful import fields, marshal_with
from models.model import App
from models.model import App, AppModelConfig
from models.tools import ApiToolProvider

import json

from extensions.ext_database import db


class AppParameterApi(AppApiResource):
@@ -28,6 +33,7 @@ class AppParameterApi(AppApiResource):
        'suggested_questions': fields.Raw,
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
        'text_to_speech': fields.Raw,
        'retriever_resource': fields.Raw,
        'annotation_reply': fields.Raw,
        'more_like_this': fields.Raw,
@@ -47,6 +53,7 @@ class AppParameterApi(AppApiResource):
            'suggested_questions': app_model_config.suggested_questions_list,
            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
            'speech_to_text': app_model_config.speech_to_text_dict,
            'text_to_speech': app_model_config.text_to_speech_dict,
            'retriever_resource': app_model_config.retriever_resource_dict,
            'annotation_reply': app_model_config.annotation_reply_dict,
            'more_like_this': app_model_config.more_like_this_dict,
@@ -58,5 +65,42 @@ class AppParameterApi(AppApiResource):
            }
        }

class AppMetaApi(AppApiResource):
    def get(self, app_model: App, end_user):
        """Get app meta"""
        app_model_config: AppModelConfig = app_model.app_model_config

        agent_config = app_model_config.agent_mode_dict or {}
        meta = {
            'tool_icons': {}
        }

        # get all tools
        tools = agent_config.get('tools', [])
        url_prefix = (current_app.config.get("CONSOLE_API_URL")
                      + "/console/api/workspaces/current/tool-provider/builtin/")
        for tool in tools:
            keys = list(tool.keys())
            if len(keys) >= 4:
                # current tool standard
                provider_type = tool.get('provider_type')
                provider_id = tool.get('provider_id')
                tool_name = tool.get('tool_name')
                if provider_type == 'builtin':
                    meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon'
                elif provider_type == 'api':
                    try:
                        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
                            ApiToolProvider.id == provider_id
                        ).first()
                        meta['tool_icons'][tool_name] = json.loads(provider.icon)
                    except Exception:
                        meta['tool_icons'][tool_name] = {
                            "background": "#252525",
                            "content": "\ud83d\ude01"
                        }

        return meta

api.add_resource(AppParameterApi, '/parameters')
api.add_resource(AppMetaApi, '/meta')
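The new /meta route lets service-API callers resolve icons for the tools an agent app uses. A minimal sketch of a call, where the base URL and the app API key are placeholder assumptions (the '/v1' prefix comes from the blueprint above):

    import requests

    resp = requests.get(
        'http://localhost:5001/v1/meta',                   # assumed base URL of the service API
        headers={'Authorization': 'Bearer app-xxxxxxxx'},  # placeholder app API key
    )
    # expected shape: {"tool_icons": {"<tool_name>": "<icon url>" or {"background": ..., "content": ...}}}
    print(resp.json())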
@@ -10,6 +10,7 @@ from controllers.service_api.wraps import AppApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from flask import request
from flask_restful import reqparse
from models.model import App, AppModelConfig
from services.audio_service import AudioService
from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError,
@@ -22,14 +23,15 @@ class AudioApi(AppApiResource):
        app_model_config: AppModelConfig = app_model.app_model_config

        if not app_model_config.speech_to_text_dict['enabled']:
            raise AppUnavailableError()
            raise AppUnavailableError()

        file = request.files['file']

        try:
            response = AudioService.transcript(
            response = AudioService.transcript_asr(
                tenant_id=app_model.tenant_id,
                file=file,
                end_user=end_user
            )

            return response
@@ -57,5 +59,49 @@ class AudioApi(AppApiResource):
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()

api.add_resource(AudioApi, '/audio-to-text')


class TextApi(AppApiResource):
    def post(self, app_model: App, end_user):
        parser = reqparse.RequestParser()
        parser.add_argument('text', type=str, required=True, nullable=False, location='json')
        parser.add_argument('user', type=str, required=True, nullable=False, location='json')
        args = parser.parse_args()

        try:
            response = AudioService.transcript_tts(
                tenant_id=app_model.tenant_id,
                text=args['text'],
                end_user=args['user'],
                streaming=False
            )

            return response
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
        except NoAudioUploadedServiceError:
            raise NoAudioUploadedError()
        except AudioTooLargeServiceError as e:
            raise AudioTooLargeError(str(e))
        except UnsupportedAudioTypeServiceError:
            raise UnsupportedAudioTypeError()
        except ProviderNotSupportSpeechToTextServiceError:
            raise ProviderNotSupportSpeechToTextError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()


api.add_resource(AudioApi, '/audio-to-text')
api.add_resource(TextApi, '/text-to-audio')
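A sketch of calling the new service-API TTS route; the parser above grounds the required 'text' and 'user' JSON fields, while the base URL, app key, and output format are assumptions:

    import requests

    resp = requests.post(
        'http://localhost:5001/v1/text-to-audio',          # assumed base URL; route name is from the diff
        headers={'Authorization': 'Bearer app-xxxxxxxx'},  # placeholder app API key
        json={'text': 'Hello from Dify', 'user': 'end-user-123'},
    )
    # the response body is expected to be audio bytes here; the exact
    # container format depends on the configured TTS provider
    with open('hello.mp3', 'wb') as f:
        f.write(resp.content)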
@@ -13,7 +13,7 @@ from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from flask import Response, stream_with_context
from flask import Response, stream_with_context, request
from flask_restful import reqparse
from libs.helper import uuid_value
from services.completion_service import CompletionService
@@ -75,11 +75,18 @@ class CompletionApi(AppApiResource):


class CompletionStopApi(AppApiResource):
    def post(self, app_model, end_user, task_id):
    def post(self, app_model, _, task_id):
        if app_model.mode != 'completion':
            raise AppUnavailableError()

        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
        parser = reqparse.RequestParser()
        parser.add_argument('user', required=True, nullable=False, type=str, location='json')

        args = parser.parse_args()

        end_user_id = args.get('user')

        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user_id)

        return {'result': 'success'}, 200
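With this change the stop endpoints identify the end user from the request body rather than from the authenticated end_user object, so callers must now send a 'user' field. A hedged example; the exact stop path and the task id are assumptions, since neither is shown in this hunk:

    import requests

    task_id = '9c0cd31f-0000-0000-0000-000000000000'  # placeholder task id from a streaming response
    requests.post(
        f'http://localhost:5001/v1/completion-messages/{task_id}/stop',  # assumed route for CompletionStopApi
        headers={'Authorization': 'Bearer app-xxxxxxxx'},                # placeholder app API key
        json={'user': 'end-user-123'},                                   # now required by the JSON parser above
    )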
@@ -139,11 +146,13 @@ class ChatApi(AppApiResource):


class ChatStopApi(AppApiResource):
    def post(self, app_model, end_user, task_id):
    def post(self, app_model, _, task_id):
        if app_model.mode != 'chat':
            raise NotChatAppError()

        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
        end_user_id = request.get_json().get('user')

        ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user_id)

        return {'result': 'success'}, 200

@@ -153,29 +162,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
        return Response(response=json.dumps(response), status=200, mimetype='application/json')
    else:
        def generate() -> Generator:
            try:
                for chunk in response:
                    yield chunk
            except services.errors.conversation.ConversationNotExistsError:
                yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
            except services.errors.conversation.ConversationCompletedError:
                yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
            except services.errors.app_model_config.AppModelConfigBrokenError:
                logging.exception("App model config broken.")
                yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
            except ProviderTokenNotInitError as ex:
                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
            except QuotaExceededError:
                yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
            except ModelCurrentlyNotSupportError:
                yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
            except InvokeError as e:
                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
            except ValueError as e:
                yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
            except Exception:
                logging.exception("internal server error.")
                yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
            for chunk in response:
                yield chunk

        return Response(stream_with_context(generate()), status=200,
                        mimetype='text/event-stream')
@@ -86,5 +86,4 @@ class ConversationRenameApi(AppApiResource):

api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='conversation_name')
api.add_resource(ConversationApi, '/conversations')
api.add_resource(ConversationApi, '/conversations/<uuid:c_id>', endpoint='conversation')
api.add_resource(ConversationDetailApi, '/conversations/<uuid:c_id>', endpoint='conversation_detail')
@@ -37,6 +37,19 @@ class MessageListApi(AppApiResource):
        'created_at': TimestampField
    }

    agent_thought_fields = {
        'id': fields.String,
        'chain_id': fields.String,
        'message_id': fields.String,
        'position': fields.Integer,
        'thought': fields.String,
        'tool': fields.String,
        'tool_input': fields.String,
        'created_at': TimestampField,
        'observation': fields.String,
        'message_files': fields.List(fields.String, attribute='files')
    }

    message_fields = {
        'id': fields.String,
        'conversation_id': fields.String,
@@ -46,7 +59,8 @@ class MessageListApi(AppApiResource):
        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
        'created_at': TimestampField
        'created_at': TimestampField,
        'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
    }

    message_infinite_scroll_pagination_fields = {
@@ -1,3 +1,4 @@
from models.dataset import Dataset
import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetNameDuplicateError
@@ -68,7 +69,7 @@ class DatasetApi(DatasetApiResource):
                            help='type is required. Name must be between 1 to 40 characters.',
                            type=_validate_name)
        parser.add_argument('indexing_technique', type=str, location='json',
                            choices=('high_quality', 'economy'),
                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
                            help='Invalid indexing technique.')
        args = parser.parse_args()

16
api/controllers/service_api/index.py
Normal file
@@ -0,0 +1,16 @@
from flask import current_app
from flask_restful import Resource

from controllers.service_api import api


class IndexApi(Resource):
    def get(self):
        return {
            "welcome": "Dify OpenAPI",
            "api_version": "v1",
            "server_version": current_app.config['CURRENT_VERSION']
        }


api.add_resource(IndexApi, '/')
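The new index route gives the service API a simple version probe at its root. For example, assuming a local deployment:

    import requests

    print(requests.get('http://localhost:5001/v1/').json())  # assumed local base URL
    # -> {'welcome': 'Dify OpenAPI', 'api_version': 'v1', 'server_version': '...'}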
@@ -75,7 +75,7 @@ def validate_dataset_token(view=None):
        tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
            .filter(Tenant.id == api_token.tenant_id) \
            .filter(TenantAccountJoin.tenant_id == Tenant.id) \
            .filter(TenantAccountJoin.role == 'owner') \
            .filter(TenantAccountJoin.role.in_(['owner', 'admin'])) \
            .one_or_none()
        if tenant_account_join:
            tenant, ta = tenant_account_join
@@ -3,7 +3,12 @@ from controllers.web import api
from controllers.web.wraps import WebApiResource
from flask import current_app
from flask_restful import fields, marshal_with
from models.model import App
from models.model import App, AppModelConfig
from models.tools import ApiToolProvider

from extensions.ext_database import db

import json


class AppParameterApi(WebApiResource):
@@ -27,6 +32,7 @@ class AppParameterApi(WebApiResource):
        'suggested_questions': fields.Raw,
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
        'text_to_speech': fields.Raw,
        'retriever_resource': fields.Raw,
        'annotation_reply': fields.Raw,
        'more_like_this': fields.Raw,
@@ -46,6 +52,7 @@ class AppParameterApi(WebApiResource):
            'suggested_questions': app_model_config.suggested_questions_list,
            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
            'speech_to_text': app_model_config.speech_to_text_dict,
            'text_to_speech': app_model_config.text_to_speech_dict,
            'retriever_resource': app_model_config.retriever_resource_dict,
            'annotation_reply': app_model_config.annotation_reply_dict,
            'more_like_this': app_model_config.more_like_this_dict,
@@ -57,5 +64,42 @@ class AppParameterApi(WebApiResource):
            }
        }

class AppMeta(WebApiResource):
    def get(self, app_model: App, end_user):
        """Get app meta"""
        app_model_config: AppModelConfig = app_model.app_model_config

        agent_config = app_model_config.agent_mode_dict or {}
        meta = {
            'tool_icons': {}
        }

        # get all tools
        tools = agent_config.get('tools', [])
        url_prefix = (current_app.config.get("CONSOLE_API_URL")
                      + "/console/api/workspaces/current/tool-provider/builtin/")
        for tool in tools:
            keys = list(tool.keys())
            if len(keys) >= 4:
                # current tool standard
                provider_type = tool.get('provider_type')
                provider_id = tool.get('provider_id')
                tool_name = tool.get('tool_name')
                if provider_type == 'builtin':
                    meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon'
                elif provider_type == 'api':
                    try:
                        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
                            ApiToolProvider.id == provider_id
                        ).first()
                        meta['tool_icons'][tool_name] = json.loads(provider.icon)
                    except Exception:
                        meta['tool_icons'][tool_name] = {
                            "background": "#252525",
                            "content": "\ud83d\ude01"
                        }

        return meta

api.add_resource(AppParameterApi, '/parameters')
api.add_resource(AppMeta, '/meta')
@@ -28,7 +28,7 @@ class AudioApi(WebApiResource):
        file = request.files['file']

        try:
            response = AudioService.transcript(
            response = AudioService.transcript_asr(
                tenant_id=app_model.tenant_id,
                file=file,
            )
@@ -59,4 +59,43 @@ class AudioApi(WebApiResource):
            logging.exception("internal server error.")
            raise InternalServerError()

api.add_resource(AudioApi, '/audio-to-text')

class TextApi(WebApiResource):
    def post(self, app_model: App, end_user):
        try:
            response = AudioService.transcript_tts(
                tenant_id=app_model.tenant_id,
                text=request.form['text'],
                end_user=end_user.external_user_id,
                streaming=False
            )

            return {'data': response.data.decode('latin1')}
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
        except NoAudioUploadedServiceError:
            raise NoAudioUploadedError()
        except AudioTooLargeServiceError as e:
            raise AudioTooLargeError(str(e))
        except UnsupportedAudioTypeServiceError:
            raise UnsupportedAudioTypeError()
        except ProviderNotSupportSpeechToTextServiceError:
            raise ProviderNotSupportSpeechToTextError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            raise InternalServerError()


api.add_resource(AudioApi, '/audio-to-text')
api.add_resource(TextApi, '/text-to-audio')
@@ -146,29 +146,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
        return Response(response=json.dumps(response), status=200, mimetype='application/json')
    else:
        def generate() -> Generator:
            try:
                for chunk in response:
                    yield chunk
            except services.errors.conversation.ConversationNotExistsError:
                yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
            except services.errors.conversation.ConversationCompletedError:
                yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
            except services.errors.app_model_config.AppModelConfigBrokenError:
                logging.exception("App model config broken.")
                yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
            except ProviderTokenNotInitError as ex:
                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
            except QuotaExceededError:
                yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
            except ModelCurrentlyNotSupportError:
                yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
            except InvokeError as e:
                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
            except ValueError as e:
                yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
            except Exception:
                logging.exception("internal server error.")
                yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
            for chunk in response:
                yield chunk

        return Response(stream_with_context(generate()), status=200,
                        mimetype='text/event-stream')
@@ -14,6 +14,7 @@ from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields
from flask import Response, stream_with_context
from flask_restful import fields, marshal_with, reqparse
from flask_restful.inputs import int_range
@@ -59,7 +60,8 @@ class MessageListApi(WebApiResource):
        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
        'created_at': TimestampField
        'created_at': TimestampField,
        'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
    }

    message_infinite_scroll_pagination_fields = {
@@ -151,26 +153,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
        return Response(response=json.dumps(response), status=200, mimetype='application/json')
    else:
        def generate() -> Generator:
            try:
                for chunk in response:
                    yield chunk
            except MessageNotExistsError:
                yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
            except MoreLikeThisDisabledError:
                yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
            except ProviderTokenNotInitError as ex:
                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
            except QuotaExceededError:
                yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
            except ModelCurrentlyNotSupportError:
                yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
            except InvokeError as e:
                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
            except ValueError as e:
                yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
            except Exception:
                logging.exception("internal server error.")
                yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
            for chunk in response:
                yield chunk

        return Response(stream_with_context(generate()), status=200,
                        mimetype='text/event-stream')
@@ -13,8 +13,8 @@ from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from langchain.agents import AgentExecutor as LCAgentExecutor
from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent
from langchain.callbacks.manager import Callbacks
@@ -1,251 +0,0 @@
import json
import logging
from typing import cast

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, PromptTemplateEntity
from core.features.agent_runner import AgentRunnerFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain

logger = logging.getLogger(__name__)


class AgentApplicationRunner(AppRunner):
    """
    Agent Application Runner
    """

    def run(self, application_generate_entity: ApplicationGenerateEntity,
            queue_manager: ApplicationQueueManager,
            conversation: Conversation,
            message: Message) -> None:
        """
        Run agent application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError(f"App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            memory = TokenBufferMemory(
                conversation=conversation,
                model_instance=model_instance
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        # memory(optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=None,
            memory=memory
        )

        # Create MessageChain
        message_chain = self._init_message_chain(
            message=message,
            query=query
        )

        # add agent callback to record agent thoughts
        agent_callback = AgentLoopGatherCallbackHandler(
            model_config=app_orchestration_config.model_config,
            message=message,
            queue_manager=queue_manager,
            message_chain=message_chain
        )

        # init LLM Callback
        agent_llm_callback = AgentLLMCallback(
            agent_callback=agent_callback
        )

        agent_runner = AgentRunnerFeature(
            tenant_id=application_generate_entity.tenant_id,
            app_orchestration_config=app_orchestration_config,
            model_config=app_orchestration_config.model_config,
            config=app_orchestration_config.agent,
            queue_manager=queue_manager,
            message=message,
            user_id=application_generate_entity.user_id,
            agent_llm_callback=agent_llm_callback,
            callback=agent_callback,
            memory=memory
        )

        # agent run
        result = agent_runner.run(
            query=query,
            invoke_from=application_generate_entity.invoke_from
        )

        if result:
            self._save_message_chain(
                message_chain=message_chain,
                output_text=result
            )

        if (result
                and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE
                and app_orchestration_config.prompt_template.simple_prompt_template
        ):
            # Direct output if agent result exists and has pre prompt
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=app_orchestration_config,
                prompt_messages=prompt_messages,
                stream=application_generate_entity.stream,
                text=result,
                usage=self._get_usage_of_all_agent_thoughts(
                    model_config=app_orchestration_config.model_config,
                    message=message
                )
            )
        else:
            # As normal LLM run, agent result as context
            context = result

            # reorganize all inputs and template to prompt messages
            # Include: prompt template, inputs, query(optional), files(optional)
            # memory(optional), external data, dataset context(optional)
            prompt_messages, stop = self.organize_prompt_messages(
                app_record=app_record,
                model_config=app_orchestration_config.model_config,
                prompt_template_entity=app_orchestration_config.prompt_template,
                inputs=inputs,
                files=files,
                query=query,
                context=context,
                memory=memory
            )

            # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
            self.recale_llm_max_tokens(
                model_config=app_orchestration_config.model_config,
                prompt_messages=prompt_messages
            )

            # Invoke model
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            invoke_result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                stop=stop,
                stream=application_generate_entity.stream,
                user=application_generate_entity.user_id,
            )

            # handle invoke result
            self._handle_invoke_result(
                invoke_result=invoke_result,
                queue_manager=queue_manager,
                stream=application_generate_entity.stream
            )

    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
        """
        Init MessageChain
        :param message: message
        :param query: query
        :return:
        """
        message_chain = MessageChain(
            message_id=message.id,
            type="AgentExecutor",
            input=json.dumps({
                "input": query
            })
        )

        db.session.add(message_chain)
        db.session.commit()

        return message_chain

    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
        """
        Save MessageChain
        :param message_chain: message chain
        :param output_text: output text
        :return:
        """
        message_chain.output = json.dumps({
            "output": output_text
        })
        db.session.commit()

    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                         message: Message) -> LLMUsage:
        """
        Get usage of all agent thoughts
        :param model_config: model config
        :param message: message
        :return:
        """
        agent_thoughts = (db.session.query(MessageAgentThought)
                          .filter(MessageAgentThought.message_id == message.id).all())

        all_message_tokens = 0
        all_answer_tokens = 0
        for agent_thought in agent_thoughts:
            all_message_tokens += agent_thought.message_token
            all_answer_tokens += agent_thought.answer_token

        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        return model_type_instance._calc_response_usage(
            model_config.model,
            model_config.credentials,
            all_message_tokens,
            all_answer_tokens
        )
@@ -2,7 +2,8 @@ import time
from typing import Generator, List, Optional, Tuple, Union, cast

from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import AppOrchestrationConfigEntity, ModelConfigEntity, PromptTemplateEntity
from core.entities.application_entities import AppOrchestrationConfigEntity, ModelConfigEntity, \
    PromptTemplateEntity, ExternalDataVariableEntity, ApplicationGenerateEntity, InvokeFrom
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -10,9 +11,12 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.features.hosting_moderation import HostingModerationFeature
from core.features.moderation import ModerationFeature
from core.features.external_data_fetch import ExternalDataFetchFeature
from core.features.annotation_reply import AnnotationReplyFeature
from core.prompt.prompt_transform import PromptTransform
from models.model import App

from models.model import App, MessageAnnotation, Message

class AppRunner:
    def get_pre_calculate_rest_tokens(self, app_record: App,
@@ -199,7 +203,8 @@ class AppRunner:

    def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
                              queue_manager: ApplicationQueueManager,
                              stream: bool) -> None:
                              stream: bool,
                              agent: bool = False) -> None:
        """
        Handle invoke result
        :param invoke_result: invoke result
@@ -210,16 +215,19 @@ class AppRunner:
        if not stream:
            self._handle_invoke_result_direct(
                invoke_result=invoke_result,
                queue_manager=queue_manager
                queue_manager=queue_manager,
                agent=agent
            )
        else:
            self._handle_invoke_result_stream(
                invoke_result=invoke_result,
                queue_manager=queue_manager
                queue_manager=queue_manager,
                agent=agent
            )

    def _handle_invoke_result_direct(self, invoke_result: LLMResult,
                                     queue_manager: ApplicationQueueManager) -> None:
                                     queue_manager: ApplicationQueueManager,
                                     agent: bool) -> None:
        """
        Handle invoke result direct
        :param invoke_result: invoke result
@@ -232,7 +240,8 @@ class AppRunner:
        )

    def _handle_invoke_result_stream(self, invoke_result: Generator,
                                     queue_manager: ApplicationQueueManager) -> None:
                                     queue_manager: ApplicationQueueManager,
                                     agent: bool) -> None:
        """
        Handle invoke result
        :param invoke_result: invoke result
@@ -244,7 +253,10 @@ class AppRunner:
        text = ''
        usage = None
        for result in invoke_result:
            queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER)
            if not agent:
                queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER)
            else:
                queue_manager.publish_agent_chunk_message(result, PublishFrom.APPLICATION_MANAGER)

            text += result.delta.message.content

@@ -271,3 +283,101 @@ class AppRunner:
            llm_result=llm_result,
            pub_from=PublishFrom.APPLICATION_MANAGER
        )

    def moderation_for_inputs(self, app_id: str,
                              tenant_id: str,
                              app_orchestration_config_entity: AppOrchestrationConfigEntity,
                              inputs: dict,
                              query: str) -> Tuple[bool, dict, str]:
        """
        Process sensitive_word_avoidance.
        :param app_id: app id
        :param tenant_id: tenant id
        :param app_orchestration_config_entity: app orchestration config entity
        :param inputs: inputs
        :param query: query
        :return:
        """
        moderation_feature = ModerationFeature()
        return moderation_feature.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_orchestration_config_entity=app_orchestration_config_entity,
            inputs=inputs,
            query=query,
        )

    def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
                                 queue_manager: ApplicationQueueManager,
                                 prompt_messages: list[PromptMessage]) -> bool:
        """
        Check hosting moderation
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param prompt_messages: prompt messages
        :return:
        """
        hosting_moderation_feature = HostingModerationFeature()
        moderation_result = hosting_moderation_feature.check(
            application_generate_entity=application_generate_entity,
            prompt_messages=prompt_messages
        )

        if moderation_result:
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=application_generate_entity.app_orchestration_config_entity,
                prompt_messages=prompt_messages,
                text="I apologize for any confusion, " \
                     "but I'm an AI assistant to be helpful, harmless, and honest.",
                stream=application_generate_entity.stream
            )

        return moderation_result

    def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
                                                app_id: str,
                                                external_data_tools: list[ExternalDataVariableEntity],
                                                inputs: dict,
                                                query: str) -> dict:
        """
        Fill in variable inputs from external data tools if exists.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        external_data_fetch_feature = ExternalDataFetchFeature()
        return external_data_fetch_feature.fetch(
            tenant_id=tenant_id,
            app_id=app_id,
            external_data_tools=external_data_tools,
            inputs=inputs,
            query=query
        )

    def query_app_annotations_to_reply(self, app_record: App,
                                       message: Message,
                                       query: str,
                                       user_id: str,
                                       invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply
        :param app_record: app record
        :param message: message
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :return:
        """
        annotation_reply_feature = AnnotationReplyFeature()
        return annotation_reply_feature.query(
            app_record=app_record,
            message=message,
            query=query,
            user_id=user_id,
            invoke_from=invoke_from
        )
342
api/core/app_runner/assistant_app_runner.py
Normal file
@@ -0,0 +1,342 @@
import json
import logging
from typing import cast

from core.app_runner.app_runner import AppRunner
from core.features.assistant_cot_runner import AssistantCotApplicationRunner
from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
    AgentEntity
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import Conversation, Message, App, MessageChain, MessageAgentThought
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)

class AssistantApplicationRunner(AppRunner):
    """
    Assistant Application Runner
    """
    def run(self, application_generate_entity: ApplicationGenerateEntity,
            queue_manager: ApplicationQueueManager,
            conversation: Conversation,
            message: Message) -> None:
        """
        Run assistant application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError(f"App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            memory = TokenBufferMemory(
                conversation=conversation,
                model_instance=model_instance
            )

        # organize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        # memory(optional)
        prompt_messages, _ = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory
        )

        # moderation
        try:
            # process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=application_generate_entity.tenant_id,
                app_orchestration_config_entity=app_orchestration_config,
                inputs=inputs,
                query=query,
            )
        except ModerationException as e:
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=app_orchestration_config,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream
            )
            return

        if query:
            # annotation reply
            annotation_reply = self.query_app_annotations_to_reply(
                app_record=app_record,
                message=message,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from
            )

            if annotation_reply:
                queue_manager.publish_annotation_reply(
                    message_annotation_id=annotation_reply.id,
                    pub_from=PublishFrom.APPLICATION_MANAGER
                )
                self.direct_output(
                    queue_manager=queue_manager,
                    app_orchestration_config=app_orchestration_config,
                    prompt_messages=prompt_messages,
                    text=annotation_reply.content,
                    stream=application_generate_entity.stream
                )
                return

        # fill in variable inputs from external data tools if exists
        external_data_tools = app_orchestration_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        # memory(optional), external data, dataset context(optional)
        prompt_messages, _ = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages
        )

        if hosting_moderation_result:
            return

        agent_entity = app_orchestration_config.agent

        # load tool variables
        tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
                                                                user_id=application_generate_entity.user_id,
                                                                tanent_id=application_generate_entity.tenant_id)

        # convert db variables to tool variables
        tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)

        message_chain = self._init_message_chain(
            message=message,
            query=query
        )

        # init model instance
        model_instance = ModelInstance(
            provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
            model=app_orchestration_config.model_config.model
        )
        prompt_message, _ = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
|
||||
query=query,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
# start agent runner
|
||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
assistant_cot_runner = AssistantCotApplicationRunner(
|
||||
tenant_id=application_generate_entity.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
app_orchestration_config=app_orchestration_config,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
config=agent_entity,
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id=application_generate_entity.user_id,
|
||||
memory=memory,
|
||||
prompt_messages=prompt_message,
|
||||
variables_pool=tool_variables,
|
||||
db_variables=tool_conversation_variables,
|
||||
)
|
||||
invoke_result = assistant_cot_runner.run(
|
||||
model_instance=model_instance,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
query=query,
|
||||
)
|
||||
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
assistant_fc_runner = AssistantFunctionCallApplicationRunner(
|
||||
tenant_id=application_generate_entity.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
app_orchestration_config=app_orchestration_config,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
config=agent_entity,
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id=application_generate_entity.user_id,
|
||||
memory=memory,
|
||||
prompt_messages=prompt_message,
|
||||
variables_pool=tool_variables,
|
||||
db_variables=tool_conversation_variables
|
||||
)
|
||||
invoke_result = assistant_fc_runner.run(
|
||||
model_instance=model_instance,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
query=query,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream,
|
||||
agent=True
|
||||
)
|
||||
|
||||
def _load_tool_variables(self, conversation_id: str, user_id: str, tanent_id: str) -> ToolConversationVariables:
|
||||
"""
|
||||
load tool variables from database
|
||||
"""
|
||||
tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
|
||||
ToolConversationVariables.conversation_id == conversation_id,
|
||||
ToolConversationVariables.tenant_id == tanent_id
|
||||
).first()
|
||||
|
||||
if tool_variables:
|
||||
# save tool variables to session, so that we can update it later
|
||||
db.session.add(tool_variables)
|
||||
else:
|
||||
# create new tool variables
|
||||
tool_variables = ToolConversationVariables(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tanent_id,
|
||||
variables_str='[]',
|
||||
)
|
||||
db.session.add(tool_variables)
|
||||
db.session.commit()
|
||||
|
||||
return tool_variables
|
||||
|
||||
def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
|
||||
"""
|
||||
convert db variables to tool variables
|
||||
"""
|
||||
return ToolRuntimeVariablePool(**{
|
||||
'conversation_id': db_variables.conversation_id,
|
||||
'user_id': db_variables.user_id,
|
||||
'tenant_id': db_variables.tenant_id,
|
||||
'pool': db_variables.variables
|
||||
})
|
||||
|
||||
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
|
||||
"""
|
||||
Init MessageChain
|
||||
:param message: message
|
||||
:param query: query
|
||||
:return:
|
||||
"""
|
||||
message_chain = MessageChain(
|
||||
message_id=message.id,
|
||||
type="AgentExecutor",
|
||||
input=json.dumps({
|
||||
"input": query
|
||||
})
|
||||
)
|
||||
|
||||
db.session.add(message_chain)
|
||||
db.session.commit()
|
||||
|
||||
return message_chain
|
||||
|
||||
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
|
||||
"""
|
||||
Save MessageChain
|
||||
:param message_chain: message chain
|
||||
:param output_text: output text
|
||||
:return:
|
||||
"""
|
||||
message_chain.output = json.dumps({
|
||||
"output": output_text
|
||||
})
|
||||
db.session.commit()
|
||||
|
||||
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
|
||||
message: Message) -> LLMUsage:
|
||||
"""
|
||||
Get usage of all agent thoughts
|
||||
:param model_config: model config
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
agent_thoughts = (db.session.query(MessageAgentThought)
|
||||
.filter(MessageAgentThought.message_id == message.id).all())
|
||||
|
||||
all_message_tokens = 0
|
||||
all_answer_tokens = 0
|
||||
for agent_thought in agent_thoughts:
|
||||
all_message_tokens += agent_thought.message_tokens
|
||||
all_answer_tokens += agent_thought.answer_tokens
|
||||
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
return model_type_instance._calc_response_usage(
|
||||
model_config.model,
|
||||
model_config.credentials,
|
||||
all_message_tokens,
|
||||
all_answer_tokens
|
||||
)
|
||||
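The pre-calculation step at the top of run() boils down to simple budget arithmetic; a minimal standalone sketch (the helper name and parameters are hypothetical, not the actual AppRunner API):

def calc_rest_tokens(prompt_tokens: int, context_size: int, max_completion_tokens: int) -> int:
    # the prompt must leave enough of the context window for the completion
    rest_tokens = context_size - max_completion_tokens - prompt_tokens
    if rest_tokens < 0:
        raise ValueError("Query or prefix prompt is too long: not enough context tokens left")
    return rest_tokens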
@@ -1,23 +1,18 @@
import logging
from typing import Optional, Tuple
from typing import Optional

from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import (ApplicationGenerateEntity, AppOrchestrationConfigEntity, DatasetEntity,
                                                ExternalDataVariableEntity, InvokeFrom, ModelConfigEntity)
from core.features.annotation_reply import AnnotationReplyFeature
from core.entities.application_entities import (ApplicationGenerateEntity, DatasetEntity,
                                                InvokeFrom, ModelConfigEntity)
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.external_data_fetch import ExternalDataFetchFeature
from core.features.hosting_moderation import HostingModerationFeature
from core.features.moderation import ModerationFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.moderation.base import ModerationException
from core.prompt.prompt_transform import AppMode
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAnnotation
from models.model import App, Conversation, Message

logger = logging.getLogger(__name__)

@@ -146,7 +141,7 @@ class BasicApplicationRunner(AppRunner):

        # get context from datasets
        context = None
        if app_orchestration_config.dataset:
        if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids:
            context = self.retrieve_dataset_context(
                tenant_id=app_record.tenant_id,
                app_record=app_record,
@@ -213,76 +208,6 @@ class BasicApplicationRunner(AppRunner):
            stream=application_generate_entity.stream
        )

    def moderation_for_inputs(self, app_id: str,
                              tenant_id: str,
                              app_orchestration_config_entity: AppOrchestrationConfigEntity,
                              inputs: dict,
                              query: str) -> Tuple[bool, dict, str]:
        """
        Process sensitive_word_avoidance.
        :param app_id: app id
        :param tenant_id: tenant id
        :param app_orchestration_config_entity: app orchestration config entity
        :param inputs: inputs
        :param query: query
        :return:
        """
        moderation_feature = ModerationFeature()
        return moderation_feature.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_orchestration_config_entity=app_orchestration_config_entity,
            inputs=inputs,
            query=query,
        )

    def query_app_annotations_to_reply(self, app_record: App,
                                       message: Message,
                                       query: str,
                                       user_id: str,
                                       invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply
        :param app_record: app record
        :param message: message
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :return:
        """
        annotation_reply_feature = AnnotationReplyFeature()
        return annotation_reply_feature.query(
            app_record=app_record,
            message=message,
            query=query,
            user_id=user_id,
            invoke_from=invoke_from
        )

    def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
                                                app_id: str,
                                                external_data_tools: list[ExternalDataVariableEntity],
                                                inputs: dict,
                                                query: str) -> dict:
        """
        Fill in variable inputs from external data tools if they exist.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        external_data_fetch_feature = ExternalDataFetchFeature()
        return external_data_fetch_feature.fetch(
            tenant_id=tenant_id,
            app_id=app_id,
            external_data_tools=external_data_tools,
            inputs=inputs,
            query=query
        )

    def retrieve_dataset_context(self, tenant_id: str,
                                 app_record: App,
                                 queue_manager: ApplicationQueueManager,
@@ -334,31 +259,4 @@ class BasicApplicationRunner(AppRunner):
            hit_callback=hit_callback,
            memory=memory
        )

    def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
                                 queue_manager: ApplicationQueueManager,
                                 prompt_messages: list[PromptMessage]) -> bool:
        """
        Check hosting moderation
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param prompt_messages: prompt messages
        :return:
        """
        hosting_moderation_feature = HostingModerationFeature()
        moderation_result = hosting_moderation_feature.check(
            application_generate_entity=application_generate_entity,
            prompt_messages=prompt_messages
        )

        if moderation_result:
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=application_generate_entity.app_orchestration_config_entity,
                prompt_messages=prompt_messages,
                text="I apologize for any confusion, " \
                     "but I'm an AI assistant to be helpful, harmless, and honest.",
                stream=application_generate_entity.stream
            )

        return moderation_result
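The hosting-moderation check above short-circuits the run when the hosted provider flags the prompt; a rough sketch of the pattern (standalone, with the feature object and the direct_output callable treated as given):

def moderate_or_continue(feature, direct_output, entity, queue_manager, prompt_messages) -> bool:
    # returns True when the run should stop because a canned reply was emitted
    flagged = feature.check(application_generate_entity=entity, prompt_messages=prompt_messages)
    if flagged:
        direct_output(
            queue_manager=queue_manager,
            app_orchestration_config=entity.app_orchestration_config_entity,
            prompt_messages=prompt_messages,
            text="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
            stream=entity.stream,
        )
    return flagged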
@@ -5,20 +5,24 @@ from typing import Generator, Optional, Union, cast

from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ApplicationGenerateEntity
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentThoughtEvent, QueueErrorEvent,
                                          QueueMessageEndEvent, QueueMessageEvent, QueueMessageReplaceEvent,
                                          QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent)
                                          QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent,
                                          QueueMessageFileEvent, QueueAgentMessageEvent)
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
                                                          PromptMessage, PromptMessageContentType, PromptMessageRole,
                                                          TextPromptMessageContent)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool_file_manager import ToolFileManager
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from pydantic import BaseModel
from services.annotation_service import AppAnnotationService

@@ -135,6 +139,8 @@ class GenerateTaskPipeline:
                    completion_tokens
                )

            self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

            # response moderation
            if self._output_moderation_handler:
                self._output_moderation_handler.stop_thread()
@@ -145,12 +151,13 @@ class GenerateTaskPipeline:
                )

            # Save message
            self._save_message(event.llm_result)
            self._save_message(self._task_state.llm_result)

            response = {
                'event': 'message',
                'task_id': self._application_generate_entity.task_id,
                'id': self._message.id,
                'message_id': self._message.id,
                'mode': self._conversation.mode,
                'answer': event.llm_result.message.content,
                'metadata': {},
@@ -161,7 +168,7 @@ class GenerateTaskPipeline:
                response['conversation_id'] = self._conversation.id

            if self._task_state.metadata:
                response['metadata'] = self._task_state.metadata
                response['metadata'] = self._get_response_metadata()

            return response
        else:
@@ -176,7 +183,9 @@ class GenerateTaskPipeline:
            event = message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
                data = self._error_to_stream_response_data(self._handle_error(event))
                yield self._yield_response(data)
                break
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
@@ -213,6 +222,8 @@ class GenerateTaskPipeline:
                        completion_tokens
                    )

                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()
@@ -244,13 +255,14 @@ class GenerateTaskPipeline:
                    'event': 'message_end',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._task_state.metadata
                    response['metadata'] = self._get_response_metadata()

                yield self._yield_response(response)
            elif isinstance(event, QueueRetrieverResourcesEvent):
@@ -274,6 +286,7 @@ class GenerateTaskPipeline:
                    .filter(MessageAgentThought.id == event.agent_thought_id)
                    .first()
                )

                if agent_thought:
                    # refresh only when the row exists, to avoid refreshing None
                    db.session.refresh(agent_thought)
                    response = {
@@ -283,16 +296,48 @@ class GenerateTaskPipeline:
                        'message_id': self._message.id,
                        'position': agent_thought.position,
                        'thought': agent_thought.thought,
                        'observation': agent_thought.observation,
                        'tool': agent_thought.tool,
                        'tool_input': agent_thought.tool_input,
                        'created_at': int(self._message.created_at.timestamp())
                        'created_at': int(self._message.created_at.timestamp()),
                        'message_files': agent_thought.files
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageEvent):
            elif isinstance(event, QueueMessageFileEvent):
                message_file: MessageFile = (
                    db.session.query(MessageFile)
                    .filter(MessageFile.id == event.message_file_id)
                    .first()
                )

                if message_file:
                    # get extension (guard against a missing row before touching the url)
                    if '.' in message_file.url:
                        extension = f'.{message_file.url.split(".")[-1]}'
                        if len(extension) > 10:
                            extension = '.bin'
                    else:
                        extension = '.bin'
                    # add sign url
                    url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension)

                    response = {
                        'event': 'message_file',
                        'id': message_file.id,
                        'type': message_file.type,
                        'belongs_to': message_file.belongs_to or 'user',
                        'url': url
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)

            elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
@@ -322,7 +367,7 @@ class GenerateTaskPipeline:
                    self._output_moderation_handler.append_new_token(delta_text)

                self._task_state.llm_result.message.content += delta_text
                response = self._handle_chunk(delta_text)
                response = self._handle_chunk(delta_text, agent=isinstance(event, QueueAgentMessageEvent))
                yield self._yield_response(response)
            elif isinstance(event, QueueMessageReplaceEvent):
                response = {
@@ -374,14 +419,14 @@ class GenerateTaskPipeline:
            extras=self._application_generate_entity.extras
        )

    def _handle_chunk(self, text: str) -> dict:
    def _handle_chunk(self, text: str, agent: bool = False) -> dict:
        """
        Handle message chunk.
        :param text: text
        :param agent: whether the chunk was produced by an agent
        :return:
        """
        response = {
            'event': 'message',
            'event': 'message' if not agent else 'agent_message',
            'id': self._message.id,
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
@@ -410,6 +455,90 @@ class GenerateTaskPipeline:
        else:
            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))

    def _error_to_stream_response_data(self, e: Exception) -> dict:
        """
        Error to stream response.
        :param e: exception
        :return:
        """
        if isinstance(e, ValueError):
            data = {
                'code': 'invalid_param',
                'message': str(e),
                'status': 400
            }
        elif isinstance(e, ProviderTokenNotInitError):
            data = {
                'code': 'provider_not_initialize',
                'message': e.description,
                'status': 400
            }
        elif isinstance(e, QuotaExceededError):
            data = {
                'code': 'provider_quota_exceeded',
                'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
                'status': 400
            }
        elif isinstance(e, ModelCurrentlyNotSupportError):
            data = {
                'code': 'model_currently_not_support',
                'message': e.description,
                'status': 400
            }
        elif isinstance(e, InvokeError):
            data = {
                'code': 'completion_request_error',
                'message': e.description,
                'status': 400
            }
        else:
            logging.error(e)
            data = {
                'code': 'internal_server_error',
                'message': 'Internal Server Error, please contact support.',
                'status': 500
            }

        return {
            'event': 'error',
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            **data
        }

    def _get_response_metadata(self) -> dict:
        """
        Get response metadata by invoke from.
        :return:
        """
        metadata = {}

        # show_retrieve_source
        if 'retriever_resources' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
            else:
                metadata['retriever_resources'] = []
                for resource in self._task_state.metadata['retriever_resources']:
                    metadata['retriever_resources'].append({
                        'segment_id': resource['segment_id'],
                        'position': resource['position'],
                        'document_name': resource['document_name'],
                        'score': resource['score'],
                        'content': resource['content'],
                    })
        # show annotation reply
        if 'annotation_reply' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']

        # show usage
        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
            metadata['usage'] = self._task_state.metadata['usage']

        return metadata

    def _yield_response(self, response: dict) -> str:
        """
        Yield response.
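The extension handling in the QueueMessageFileEvent branch above, pulled out as a standalone helper for clarity (the function name is hypothetical; the logic mirrors the branch):

def safe_extension(url: str) -> str:
    # derive an extension from the URL tail, falling back to '.bin' when there
    # is no dot or the tail is implausibly long (query strings, hashes, ...)
    if '.' in url:
        extension = '.' + url.split('.')[-1]
        if len(extension) > 10:
            extension = '.bin'
    else:
        extension = '.bin'
    return extension

assert safe_extension('https://host/files/report.pdf') == '.pdf'
assert safe_extension('https://host/files/blob') == '.bin'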
@@ -116,7 +116,7 @@ class OutputModerationHandler(BaseModel):

                # trigger replace event
                if self.thread_running:
                    self.on_message_replace_func(final_output)
                    self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    break
@@ -4,7 +4,7 @@ import threading
import uuid
from typing import Any, Generator, Optional, Tuple, Union, cast

from core.app_runner.agent_app_runner import AgentApplicationRunner
from core.app_runner.assistant_app_runner import AssistantApplicationRunner
from core.app_runner.basic_app_runner import BasicApplicationRunner
from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
@@ -13,7 +13,7 @@ from core.entities.application_entities import (AdvancedChatPromptTemplateEntity,
                                                ApplicationGenerateEntity, AppOrchestrationConfigEntity, DatasetEntity,
                                                DatasetRetrieveConfigEntity, ExternalDataVariableEntity,
                                                FileUploadEntity, InvokeFrom, ModelConfigEntity, PromptTemplateEntity,
                                                SensitiveWordAvoidanceEntity)
                                                SensitiveWordAvoidanceEntity, AgentPromptEntity)
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file.file_obj import FileObj
@@ -23,6 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from core.provider_manager import ProviderManager
from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
from extensions.ext_database import db
from flask import Flask, current_app
from models.account import Account
@@ -93,6 +94,9 @@ class ApplicationManager:
            extras=extras
        )

        if not stream and application_generate_entity.app_orchestration_config_entity.agent:
            raise ValueError("Agent app is not supported in blocking mode.")

        # init generate records
        (
            conversation,
@@ -151,7 +155,7 @@ class ApplicationManager:

            if application_generate_entity.app_orchestration_config_entity.agent:
                # agent app
                runner = AgentApplicationRunner()
                runner = AssistantApplicationRunner()
                runner.run(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
@@ -354,6 +358,8 @@ class ApplicationManager:

        # external data variables
        properties['external_data_variables'] = []

        # old external_data_tools
        external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
        for external_data_tool in external_data_tools:
            if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
@@ -366,6 +372,19 @@ class ApplicationManager:
                    config=external_data_tool['config']
                )
            )

        # current external_data_tools
        for variable in copy_app_model_config_dict.get('user_input_form', []):
            typ = list(variable.keys())[0]
            if typ == 'external_data_tool':
                val = variable[typ]
                properties['external_data_variables'].append(
                    ExternalDataVariableEntity(
                        variable=val['variable'],
                        type=val['type'],
                        config=val['config']
                    )
                )

        # show retrieve source
        show_retrieve_source = False
@@ -375,15 +394,65 @@ class ApplicationManager:
                show_retrieve_source = True

        properties['show_retrieve_source'] = show_retrieve_source

        dataset_ids = []
        if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}):
            datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', {
                'strategy': 'router',
                'datasets': []
            })

            for dataset in datasets.get('datasets', []):
                keys = list(dataset.keys())
                if len(keys) == 0 or keys[0] != 'dataset':
                    continue
                dataset = dataset['dataset']

                if 'enabled' not in dataset or not dataset['enabled']:
                    continue

                dataset_id = dataset.get('id', None)
                if dataset_id:
                    dataset_ids.append(dataset_id)
        else:
            datasets = {'strategy': 'router', 'datasets': []}

        if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
                and 'enabled' in copy_app_model_config_dict['agent_mode'] and copy_app_model_config_dict['agent_mode'][
            'enabled']:
            agent_dict = copy_app_model_config_dict.get('agent_mode')
            agent_strategy = agent_dict.get('strategy', 'router')
            if agent_strategy in ['router', 'react_router']:
                dataset_ids = []
                for tool in agent_dict.get('tools', []):
                and 'enabled' in copy_app_model_config_dict['agent_mode'] \
                and copy_app_model_config_dict['agent_mode']['enabled']:

            agent_dict = copy_app_model_config_dict.get('agent_mode', {})
            agent_strategy = agent_dict.get('strategy', 'cot')

            if agent_strategy == 'function_call':
                strategy = AgentEntity.Strategy.FUNCTION_CALLING
            elif agent_strategy == 'cot' or agent_strategy == 'react':
                strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
            else:
                # old configs, try to detect default strategy
                if copy_app_model_config_dict['model']['provider'] == 'openai':
                    strategy = AgentEntity.Strategy.FUNCTION_CALLING
                else:
                    strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT

            agent_tools = []
            for tool in agent_dict.get('tools', []):
                keys = tool.keys()
                if len(keys) >= 4:
                    if "enabled" not in tool or not tool["enabled"]:
                        continue

                    agent_tool_properties = {
                        'provider_type': tool['provider_type'],
                        'provider_id': tool['provider_id'],
                        'tool_name': tool['tool_name'],
                        'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {}
                    }

                    agent_tools.append(AgentToolEntity(**agent_tool_properties))
                elif len(keys) == 1:
                    # old standard
                    key = list(tool.keys())[0]

                    if key != 'dataset':
@@ -396,59 +465,60 @@ class ApplicationManager:

                        dataset_id = tool_item['id']
                        dataset_ids.append(dataset_id)

            dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
            query_variable = copy_app_model_config_dict.get('dataset_query_variable')
            if dataset_configs['retrieval_model'] == 'single':
                properties['dataset'] = DatasetEntity(
                    dataset_ids=dataset_ids,
                    retrieve_config=DatasetRetrieveConfigEntity(
                        query_variable=query_variable,
                        retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                            dataset_configs['retrieval_model']
                        ),
                        single_strategy=agent_strategy
                    )

            if 'strategy' in copy_app_model_config_dict['agent_mode'] and \
                    copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']:
                agent_prompt = agent_dict.get('prompt', None) or {}
                # check model mode
                model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion')
                if model_mode == 'completion':
                    agent_prompt_entity = AgentPromptEntity(
                        first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
                        next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']),
                    )
                else:
                properties['dataset'] = DatasetEntity(
                    dataset_ids=dataset_ids,
                    retrieve_config=DatasetRetrieveConfigEntity(
                        query_variable=query_variable,
                        retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                            dataset_configs['retrieval_model']
                        ),
                        top_k=dataset_configs.get('top_k'),
                        score_threshold=dataset_configs.get('score_threshold'),
                        reranking_model=dataset_configs.get('reranking_model')
                    )
                    agent_prompt_entity = AgentPromptEntity(
                        first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
                        next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
                    )
            else:
                if agent_strategy == 'react':
                    strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
                else:
                    strategy = AgentEntity.Strategy.FUNCTION_CALLING

            agent_tools = []
            for tool in agent_dict.get('tools', []):
                key = list(tool.keys())[0]
                tool_item = tool[key]

                agent_tool_properties = {
                    "tool_id": key
                }

                if "enabled" not in tool_item or not tool_item["enabled"]:
                    continue

                agent_tool_properties["config"] = tool_item
                agent_tools.append(AgentToolEntity(**agent_tool_properties))

            properties['agent'] = AgentEntity(
                provider=properties['model_config'].provider,
                model=properties['model_config'].model,
                strategy=strategy,
                tools=agent_tools
                prompt=agent_prompt_entity,
                tools=agent_tools,
                max_iteration=agent_dict.get('max_iteration', 5)
            )

        if len(dataset_ids) > 0:
            # dataset configs
            dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
            query_variable = copy_app_model_config_dict.get('dataset_query_variable')

            if dataset_configs['retrieval_model'] == 'single':
                properties['dataset'] = DatasetEntity(
                    dataset_ids=dataset_ids,
                    retrieve_config=DatasetRetrieveConfigEntity(
                        query_variable=query_variable,
                        retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                            dataset_configs['retrieval_model']
                        ),
                        single_strategy=datasets.get('strategy', 'router')
                    )
                )
            else:
                properties['dataset'] = DatasetEntity(
                    dataset_ids=dataset_ids,
                    retrieve_config=DatasetRetrieveConfigEntity(
                        query_variable=query_variable,
                        retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                            dataset_configs['retrieval_model']
                        ),
                        top_k=dataset_configs.get('top_k'),
                        score_threshold=dataset_configs.get('score_threshold'),
                        reranking_model=dataset_configs.get('reranking_model')
                    )
                )

        # file upload
@@ -485,6 +555,12 @@ class ApplicationManager:
            if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
                properties['speech_to_text'] = True

        # text to speech
        text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
        if text_to_speech_dict:
            if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
                properties['text_to_speech'] = True

        # sensitive word avoidance
        sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
        if sensitive_word_avoidance_dict:
@@ -601,6 +677,7 @@ class ApplicationManager:
                message_id=message.id,
                type=file.type.value,
                transfer_method=file.transfer_method.value,
                belongs_to='user',
                url=file.url,
                upload_file_id=file.upload_file_id,
                created_by_role=('account' if account_id else 'end_user'),
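The strategy resolution introduced above is a small decision table; a self-contained sketch of the same logic (enum values are illustrative stand-ins for the real AgentEntity.Strategy):

from enum import Enum

class Strategy(Enum):
    CHAIN_OF_THOUGHT = 'chain-of-thought'
    FUNCTION_CALLING = 'function-calling'

def resolve_strategy(agent_strategy: str, provider: str) -> Strategy:
    # explicit config wins
    if agent_strategy == 'function_call':
        return Strategy.FUNCTION_CALLING
    if agent_strategy in ('cot', 'react'):
        return Strategy.CHAIN_OF_THOUGHT
    # old configs: fall back by provider, mirroring the diff's default detection
    return Strategy.FUNCTION_CALLING if provider == 'openai' else Strategy.CHAIN_OF_THOUGHT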
@@ -7,10 +7,10 @@ from core.entities.application_entities import InvokeFrom
from core.entities.queue_entities import (AnnotationReplyEvent, AppQueueEvent, QueueAgentThoughtEvent, QueueErrorEvent,
                                          QueueMessage, QueueMessageEndEvent, QueueMessageEvent,
                                          QueueMessageReplaceEvent, QueuePingEvent, QueueRetrieverResourcesEvent,
                                          QueueStopEvent)
                                          QueueStopEvent, QueueMessageFileEvent, QueueAgentMessageEvent)
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from extensions.ext_redis import redis_client
from models.model import MessageAgentThought
from models.model import MessageAgentThought, MessageFile
from sqlalchemy.orm import DeclarativeMeta


@@ -96,6 +96,18 @@ class ApplicationQueueManager:
            chunk=chunk
        ), pub_from)

    def publish_agent_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
        """
        Publish agent chunk message to channel

        :param chunk: chunk
        :param pub_from: publish from
        :return:
        """
        self.publish(QueueAgentMessageEvent(
            chunk=chunk
        ), pub_from)

    def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
        """
        Publish message replace
@@ -144,6 +156,17 @@ class ApplicationQueueManager:
            agent_thought_id=message_agent_thought.id
        ), pub_from)

    def publish_message_file(self, message_file: MessageFile, pub_from: PublishFrom) -> None:
        """
        Publish message file
        :param message_file: message file
        :param pub_from: publish from
        :return:
        """
        self.publish(QueueMessageFileEvent(
            message_file_id=message_file.id
        ), pub_from)

    def publish_error(self, e, pub_from: PublishFrom) -> None:
        """
        Publish error
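The two new publish helpers wrap their payloads in the corresponding queue events; a hypothetical call site from an agent runner (the surrounding names are assumed, not taken from this diff):

# inside an agent runner, streaming model output and attaching a generated file
queue_manager.publish_agent_chunk_message(chunk, PublishFrom.APPLICATION_MANAGER)
queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)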
74
api/core/callback_handler/agent_tool_callback_handler.py
Normal file
@@ -0,0 +1,74 @@
import os
from typing import Any, Dict, Optional, Union
from pydantic import BaseModel

from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text


class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
    """Callback Handler that prints to std out."""
    color: Optional[str] = ''
    current_loop = 1

    def __init__(self, color: Optional[str] = None) -> None:
        """Initialize callback handler."""
        super().__init__()
        # fall back to a default color if none is specified
        self.color = color or 'green'
        self.current_loop = 1

    def on_tool_start(
        self,
        tool_name: str,
        tool_inputs: Dict[str, Any],
    ) -> None:
        """Print out the tool call."""
        print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)

    def on_tool_end(
        self,
        tool_name: str,
        tool_inputs: Dict[str, Any],
        tool_outputs: str,
    ) -> None:
        """If not the final action, print out observation."""
        print_text("\n[on_tool_end]\n", color=self.color)
        print_text("Tool: " + tool_name + "\n", color=self.color)
        print_text("Inputs: " + str(tool_inputs) + "\n", color=self.color)
        print_text("Outputs: " + str(tool_outputs) + "\n", color=self.color)
        print_text("\n")

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Print out the tool error."""
        print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red')

    def on_agent_start(
        self, thought: str
    ) -> None:
        """Run on agent start."""
        if thought:
            print_text("\n[on_agent_start] \nCurrent Loop: " + \
                       str(self.current_loop) + \
                       "\nThought: " + thought + "\n", color=self.color)
        else:
            print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)

    def on_agent_finish(
        self, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Run on agent end."""
        print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)

        self.current_loop += 1

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'

    @property
    def ignore_chat_model(self) -> bool:
        """Whether to ignore chat model callbacks."""
        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
@@ -27,7 +27,7 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM

class FileExtractor:
    @classmethod
    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document] | str]:
    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
@@ -36,7 +36,7 @@ class FileExtractor:
            return cls.load_from_file(file_path, return_text, upload_file, is_automatic)

    @classmethod
    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]:
        response = requests.get(url, headers={
            "User-Agent": USER_AGENT
        })
@@ -52,7 +52,7 @@ class FileExtractor:
    @classmethod
    def load_from_file(cls, file_path: str, return_text: bool = False,
                       upload_file: Optional[UploadFile] = None,
                       is_automatic: bool = False) -> Union[List[Document] | str]:
                       is_automatic: bool = False) -> Union[List[Document], str]:
        input_file = Path(file_path)
        delimiter = '\n'
        file_extension = input_file.suffix.lower()
@@ -68,7 +68,7 @@ class FileExtractor:
                else MarkdownLoader(file_path, autodetect_encoding=True)
        elif file_extension in ['.htm', '.html']:
            loader = HTMLLoader(file_path)
        elif file_extension == '.docx':
        elif file_extension in ['.docx', '.doc']:
            loader = Docx2txtLoader(file_path)
        elif file_extension == '.csv':
            loader = CSVLoader(file_path, autodetect_encoding=True)
@@ -95,7 +95,7 @@ class FileExtractor:
            loader = MarkdownLoader(file_path, autodetect_encoding=True)
        elif file_extension in ['.htm', '.html']:
            loader = HTMLLoader(file_path)
        elif file_extension == '.docx':
        elif file_extension in ['.docx', '.doc']:
            loader = Docx2txtLoader(file_path)
        elif file_extension == '.csv':
            loader = CSVLoader(file_path, autodetect_encoding=True)
@@ -1,9 +1,7 @@
import logging
import re
from typing import List, Optional, Tuple, cast
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)

@@ -1,14 +1,10 @@
import logging
import re
from typing import List, Optional, Tuple, cast

from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredPPTLoader(BaseLoader):
    """Load ppt files.


@@ -1,14 +1,10 @@
import logging
import re
from typing import List, Optional, Tuple, cast
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class UnstructuredPPTXLoader(BaseLoader):
    """Load pptx files.


@@ -1,9 +1,7 @@
import logging
import re
from typing import List, Optional, Tuple, cast
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)

@@ -1,9 +1,7 @@
import logging
import re
from typing import List, Optional, Tuple, cast
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)
@@ -1,10 +1,16 @@
import base64
import json
import logging
from typing import List, Optional
from typing import List, Optional, cast

import numpy as np
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from extensions.ext_database import db
from langchain.embeddings.base import Embeddings

from extensions.ext_redis import redis_client
from libs import helper
from models.dataset import Embedding
from sqlalchemy.exc import IntegrityError
@@ -18,47 +24,33 @@ class CacheEmbedding(Embeddings):
        self._user = user

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        # use doc embedding cache or store if not exists
        text_embeddings = [None for _ in range(len(texts))]
        embedding_queue_indices = []
        for i, text in enumerate(texts):
            hash = helper.generate_text_hash(text)
            embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
            if embedding:
                text_embeddings[i] = embedding.get_embedding()
            else:
                embedding_queue_indices.append(i)
        """Embed search docs in batches of 10."""
        text_embeddings = []
        try:
            model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
            model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
            max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
                if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
            for i in range(0, len(texts), max_chunks):
                batch_texts = texts[i:i + max_chunks]

        if embedding_queue_indices:
            try:
                embedding_result = self._model_instance.invoke_text_embedding(
                    texts=[texts[i] for i in embedding_queue_indices],
                    texts=batch_texts,
                    user=self._user
                )

                embedding_results = embedding_result.embeddings
            except Exception as ex:
                logger.error('Failed to embed documents: %s', ex)
                raise ex
                for vector in embedding_result.embeddings:
                    try:
                        normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
                        text_embeddings.append(normalized_embedding)
                    except IntegrityError:
                        db.session.rollback()
                    except Exception as e:
                        logging.exception('Failed to add embedding to redis')

        for i, indice in enumerate(embedding_queue_indices):
            hash = helper.generate_text_hash(texts[indice])

            try:
                embedding = Embedding(model_name=self._model_instance.model, hash=hash)
                vector = embedding_results[i]
                normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
                text_embeddings[indice] = normalized_embedding
                embedding.set_embedding(normalized_embedding)
                db.session.add(embedding)
                db.session.commit()
            except IntegrityError:
                db.session.rollback()
                continue
            except:
                logging.exception('Failed to add embedding to db')
                continue
        except Exception as ex:
            logger.error('Failed to embed documents: %s', ex)
            raise ex

        return text_embeddings

@@ -66,9 +58,12 @@ class CacheEmbedding(Embeddings):
        """Embed query text."""
        # use doc embedding cache or store if not exists
        hash = helper.generate_text_hash(text)
        embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
        embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
        embedding = redis_client.get(embedding_cache_key)
        if embedding:
            return embedding.get_embedding()
            redis_client.expire(embedding_cache_key, 600)
            return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))

        try:
            embedding_result = self._model_instance.invoke_text_embedding(
@@ -82,13 +77,18 @@ class CacheEmbedding(Embeddings):
            raise ex

        try:
            embedding = Embedding(model_name=self._model_instance.model, hash=hash)
            embedding.set_embedding(embedding_results)
            db.session.add(embedding)
            db.session.commit()
            # encode embedding to base64
            embedding_vector = np.array(embedding_results)
            vector_bytes = embedding_vector.tobytes()
            # Transform to Base64
            encoded_vector = base64.b64encode(vector_bytes)
            # Transform to string
            encoded_str = encoded_vector.decode("utf-8")
            redis_client.setex(embedding_cache_key, 600, encoded_str)

        except IntegrityError:
            db.session.rollback()
        except:
            logging.exception('Failed to add embedding to db')
            logging.exception('Failed to add embedding to redis')

        return embedding_results
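The Redis cache above stores vectors as base64-encoded raw bytes; a self-contained sketch of the round trip (the Redis calls are elided as comments, since only the encoding matters here):

import base64
import numpy as np

vector = [0.1, 0.2, 0.3]
# write side: serialize a float64 vector to bytes, then base64 for Redis
encoded_str = base64.b64encode(np.array(vector).tobytes()).decode('utf-8')
# ... redis_client.setex(key, 600, encoded_str) ...
# read side: decode base64 and reinterpret the buffer as float64
restored = list(np.frombuffer(base64.b64decode(encoded_str), dtype='float'))
assert restored == vector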
@@ -1,11 +1,12 @@
from enum import Enum
from typing import Any, Optional, cast
from typing import Optional, Any, cast, Literal, Union

from pydantic import BaseModel

from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileObj
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import AIModelEntity
from pydantic import BaseModel


class ModelConfigEntity(BaseModel):
@@ -153,9 +154,35 @@ class AgentToolEntity(BaseModel):
    """
    Agent Tool Entity.
    """
    tool_id: str
    config: dict[str, Any] = {}
    provider_type: Literal["builtin", "api"]
    provider_id: str
    tool_name: str
    tool_parameters: dict[str, Any] = {}


class AgentPromptEntity(BaseModel):
    """
    Agent Prompt Entity.
    """
    first_prompt: str
    next_iteration: str


class AgentScratchpadUnit(BaseModel):
    """
    Agent Scratchpad Unit Entity.
    """

    class Action(BaseModel):
        """
        Action Entity.
        """
        action_name: str
        action_input: Union[dict, str]

    agent_response: Optional[str] = None
    thought: Optional[str] = None
    action_str: Optional[str] = None
    observation: Optional[str] = None
    action: Optional[Action] = None


class AgentEntity(BaseModel):
    """
@@ -171,8 +198,9 @@ class AgentEntity(BaseModel):
    provider: str
    model: str
    strategy: Strategy
    tools: list[AgentToolEntity] = []

    prompt: Optional[AgentPromptEntity] = None
    tools: list[AgentToolEntity] = None
    max_iteration: int = 5


class AppOrchestrationConfigEntity(BaseModel):
    """
@@ -191,6 +219,7 @@ class AppOrchestrationConfigEntity(BaseModel):
    show_retrieve_source: bool = False
    more_like_this: bool = False
    speech_to_text: bool = False
    text_to_speech: bool = False
    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None


@@ -255,7 +284,6 @@ class ApplicationGenerateEntity(BaseModel):
    query: Optional[str] = None
    files: list[FileObj] = []
    user_id: str

    # extras
    stream: bool
    invoke_from: InvokeFrom
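A construction sketch for the extended AgentEntity (all field values below are illustrative, not taken from this diff; the prompt is normally filled upstream from REACT_PROMPT_TEMPLATES for chain-of-thought agents):

from core.entities.application_entities import AgentEntity, AgentToolEntity

agent = AgentEntity(
    provider='openai',
    model='gpt-4',
    strategy=AgentEntity.Strategy.FUNCTION_CALLING,
    prompt=None,  # upstream config code fills this from REACT_PROMPT_TEMPLATES for CoT agents
    tools=[AgentToolEntity(
        provider_type='builtin',
        provider_id='time',          # hypothetical provider id
        tool_name='current_time',    # hypothetical tool name
        tool_parameters={},
    )],
    max_iteration=5,
)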
@@ -165,7 +165,7 @@ class ProviderConfiguration(BaseModel):
            if value == '[__HIDDEN__]' and key in original_credentials:
                credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

        model_provider_factory.provider_credentials_validate(
        credentials = model_provider_factory.provider_credentials_validate(
            self.provider.provider,
            credentials
        )
@@ -308,24 +308,13 @@ class ProviderConfiguration(BaseModel):
            if value == '[__HIDDEN__]' and key in original_credentials:
                credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

        model_provider_factory.model_credentials_validate(
        credentials = model_provider_factory.model_credentials_validate(
            provider=self.provider.provider,
            model_type=model_type,
            model=model,
            credentials=credentials
        )

        model_schema = (
            model_provider_factory.get_provider_instance(self.provider.provider)
            .get_model_instance(model_type)._get_customizable_model_schema(
                model=model,
                credentials=credentials
            )
        )

        if model_schema:
            credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))

        for key, value in credentials.items():
            if key in provider_credential_secret_variables:
                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
@@ -10,11 +10,13 @@ class QueueEvent(Enum):
    QueueEvent enum
    """
    MESSAGE = "message"
    AGENT_MESSAGE = "agent_message"
    MESSAGE_REPLACE = "message-replace"
    MESSAGE_END = "message-end"
    RETRIEVER_RESOURCES = "retriever-resources"
    ANNOTATION_REPLY = "annotation-reply"
    AGENT_THOUGHT = "agent-thought"
    MESSAGE_FILE = "message-file"
    ERROR = "error"
    PING = "ping"
    STOP = "stop"
@@ -33,7 +35,14 @@ class QueueMessageEvent(AppQueueEvent):
    """
    event = QueueEvent.MESSAGE
    chunk: LLMResultChunk


class QueueAgentMessageEvent(AppQueueEvent):
    """
    QueueAgentMessageEvent entity
    """
    event = QueueEvent.AGENT_MESSAGE
    chunk: LLMResultChunk


class QueueMessageReplaceEvent(AppQueueEvent):
    """
@@ -73,7 +82,13 @@ class QueueAgentThoughtEvent(AppQueueEvent):
    """
    event = QueueEvent.AGENT_THOUGHT
    agent_thought_id: str


class QueueMessageFileEvent(AppQueueEvent):
    """
    QueueMessageFileEvent entity
    """
    event = QueueEvent.MESSAGE_FILE
    message_file_id: str

class QueueErrorEvent(AppQueueEvent):
    """

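The new AGENT_MESSAGE and MESSAGE_FILE events let the agent runners stream chunks and file attachments through the same queue as plain completions. A standalone sketch of the event-per-subclass pattern used here (simplified stand-ins, not the dify entities):

    from enum import Enum

    class QueueEvent(Enum):
        MESSAGE = "message"
        AGENT_MESSAGE = "agent_message"
        MESSAGE_FILE = "message-file"

    class AppQueueEvent:
        event: QueueEvent

    class QueueAgentMessageEvent(AppQueueEvent):
        # one subclass per event type: consumers dispatch on `event`
        event = QueueEvent.AGENT_MESSAGE

        def __init__(self, chunk: str):
            self.chunk = chunk

    evt = QueueAgentMessageEvent(chunk="partial answer...")
    print(evt.event.value)  # -> agent_message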
@@ -1,30 +1,27 @@
import logging
from typing import List, Optional, cast
from typing import cast, Optional, List

from langchain import WikipediaAPIWrapper
from langchain.callbacks.base import BaseCallbackHandler
from langchain.tools import BaseTool, WikipediaQueryRun, Tool
from pydantic import BaseModel, Field

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity, InvokeFrom,
                                                ModelConfigEntity)
from core.entities.application_entities import ModelConfigEntity, InvokeFrom, \
    AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIInput, OptimizedSerpAPIWrapper
from core.tool.web_reader_tool import WebReaderTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from langchain import WikipediaAPIWrapper
from langchain.callbacks.base import BaseCallbackHandler
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from models.dataset import Dataset
from models.model import Message
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)

@@ -132,55 +129,6 @@ class AgentRunnerFeature:
            logger.exception("agent_executor run failed")
            return None

    def to_tools(self, tool_configs: list[AgentToolEntity],
                 invoke_from: InvokeFrom,
                 callbacks: list[BaseCallbackHandler]) \
            -> Optional[List[BaseTool]]:
        """
        Convert tool configs to tools
        :param tool_configs: tool configs
        :param invoke_from: invoke from
        :param callbacks: callbacks
        """
        tools = []
        for tool_config in tool_configs:
            tool = None
            if tool_config.tool_id == "dataset":
                tool = self.to_dataset_retriever_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )
            elif tool_config.tool_id == "web_reader":
                tool = self.to_web_reader_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )
            elif tool_config.tool_id == "google_search":
                tool = self.to_google_search_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )
            elif tool_config.tool_id == "wikipedia":
                tool = self.to_wikipedia_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )
            elif tool_config.tool_id == "current_datetime":
                tool = self.to_current_datetime_tool(
                    tool_config=tool_config.config,
                    invoke_from=invoke_from
                )

            if tool:
                if tool.callbacks is not None:
                    tool.callbacks.extend(callbacks)
                else:
                    tool.callbacks = callbacks

                tools.append(tool)

        return tools

    def to_dataset_retriever_tool(self, tool_config: dict,
                                  invoke_from: InvokeFrom) \
            -> Optional[BaseTool]:
@@ -247,78 +195,4 @@ class AgentRunnerFeature:
            retriever_from=invoke_from.to_source()
        )

        return tool

    def to_web_reader_tool(self, tool_config: dict,
                           invoke_from: InvokeFrom) -> Optional[BaseTool]:
        """
        A tool for reading web pages
        :param tool_config: tool config
        :param invoke_from: invoke from
        :return:
        """
        model_parameters = {
            "temperature": 0,
            "max_tokens": 500
        }

        tool = WebReaderTool(
            model_config=self.model_config,
            model_parameters=model_parameters,
            max_chunk_length=4000,
            continue_reading=True
        )

        return tool

    def to_google_search_tool(self, tool_config: dict,
                              invoke_from: InvokeFrom) -> Optional[BaseTool]:
        """
        A tool for performing a Google search and extracting snippets and webpages
        :param tool_config: tool config
        :param invoke_from: invoke from
        :return:
        """
        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
        func_kwargs = tool_provider.credentials_to_func_kwargs()
        if not func_kwargs:
            return None

        tool = Tool(
            name="google_search",
            description="A tool for performing a Google search and extracting snippets and webpages "
                        "when you need to search for something you don't know or when your information "
                        "is not up to date. "
                        "Input should be a search query.",
            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
            args_schema=OptimizedSerpAPIInput
        )

        return tool

    def to_current_datetime_tool(self, tool_config: dict,
                                 invoke_from: InvokeFrom) -> Optional[BaseTool]:
        """
        A tool for getting the current date and time
        :param tool_config: tool config
        :param invoke_from: invoke from
        :return:
        """
        return DatetimeTool()

    def to_wikipedia_tool(self, tool_config: dict,
                          invoke_from: InvokeFrom) -> Optional[BaseTool]:
        """
        A tool for searching Wikipedia
        :param tool_config: tool config
        :param invoke_from: invoke from
        :return:
        """
        class WikipediaInput(BaseModel):
            query: str = Field(..., description="search query.")

        return WikipediaQueryRun(
            name="wikipedia",
            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
            args_schema=WikipediaInput
        )
        return tool
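The deleted to_tools chain above is the old hard-coded tool registry that the new core.tools package replaces. For comparison, the same if/elif dispatch collapses to a lookup table; a generic sketch (hypothetical factories, not dify's ToolManager):

    from typing import Callable, Optional

    def make_wikipedia_tool(config: dict) -> str:
        return "wikipedia-tool"

    def make_datetime_tool(config: dict) -> str:
        return "datetime-tool"

    # tool_id -> factory; adding a tool is one entry instead of a new elif branch
    TOOL_FACTORIES: dict[str, Callable[[dict], str]] = {
        "wikipedia": make_wikipedia_tool,
        "current_datetime": make_datetime_tool,
    }

    def to_tool(tool_id: str, config: dict) -> Optional[str]:
        factory = TOOL_FACTORIES.get(tool_id)
        return factory(config) if factory else None

    print(to_tool("wikipedia", {}))  # -> wikipedia-tool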
558 api/core/features/assistant_base_runner.py Normal file
@@ -0,0 +1,558 @@
import json
import logging
from datetime import datetime
from mimetypes import guess_extension
from typing import List, Optional, Tuple, Union

from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import AgentEntity, AgentToolEntity, ApplicationGenerateEntity, \
    AppOrchestrationConfigEntity, InvokeFrom, ModelConfigEntity
from core.file.message_file_parser import FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, \
    ToolRuntimeVariablePool, ToolParamter
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from models.model import Message, MessageAgentThought, MessageFile
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)


class BaseAssistantApplicationRunner(AppRunner):
    def __init__(self, tenant_id: str,
                 application_generate_entity: ApplicationGenerateEntity,
                 app_orchestration_config: AppOrchestrationConfigEntity,
                 model_config: ModelConfigEntity,
                 config: AgentEntity,
                 queue_manager: ApplicationQueueManager,
                 message: Message,
                 user_id: str,
                 memory: Optional[TokenBufferMemory] = None,
                 prompt_messages: Optional[List[PromptMessage]] = None,
                 variables_pool: Optional[ToolRuntimeVariablePool] = None,
                 db_variables: Optional[ToolConversationVariables] = None,
                 ) -> None:
        """
        Agent runner
        :param tenant_id: tenant id
        :param application_generate_entity: application generate entity
        :param app_orchestration_config: app orchestration config
        :param model_config: model config
        :param config: agent config
        :param queue_manager: queue manager
        :param message: message
        :param user_id: user id
        :param memory: token buffer memory
        :param prompt_messages: history prompt messages
        :param variables_pool: tool runtime variables pool
        :param db_variables: persisted tool conversation variables
        """
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.app_orchestration_config = app_orchestration_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = prompt_messages
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.application_generate_entity.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_orchestration_config.dataset.dataset_ids if app_orchestration_config.dataset else [],
            retrieve_config=app_orchestration_config.dataset.retrieve_config if app_orchestration_config.dataset else None,
            return_resource=app_orchestration_config.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback
        )
        # get how many agent thoughts have been created
        self.agent_thought_count = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.message_id == self.message.id,
        ).count()

    def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
        """
        Repacket app orchestration config, filling in defaults
        """
        if app_orchestration_config.prompt_template.simple_prompt_template is None:
            app_orchestration_config.prompt_template.simple_prompt_template = ''

        return app_orchestration_config

    def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
        """
        Handle tool response
        """
        result = ''
        for response in tool_response:
            if response.type == ToolInvokeMessage.MessageType.TEXT:
                result += response.message
            elif response.type == ToolInvokeMessage.MessageType.LINK:
                result += f"result link: {response.message}. please direct user to check it."
            elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
                    response.type == ToolInvokeMessage.MessageType.IMAGE:
                result += "image has been created and sent to user already, you should tell user to check it now."
            else:
                result += f"tool response: {response.message}."

        return result

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
        """
        convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_tool_runtime(
            provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
            tanent_id=self.application_generate_entity.tenant_id,
            agent_callback=self.agent_callback
        )
        tool_entity.load_variables(self.variables_pool)

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        )

        runtime_parameters = {}

        parameters = tool_entity.parameters or []
        user_parameters = tool_entity.get_runtime_parameters() or []

        # override parameters
        for parameter in user_parameters:
            # check if parameter in tool parameters
            found = False
            for tool_parameter in parameters:
                if tool_parameter.name == parameter.name:
                    found = True
                    break

            if found:
                # override parameter
                tool_parameter.type = parameter.type
                tool_parameter.form = parameter.form
                tool_parameter.required = parameter.required
                tool_parameter.default = parameter.default
                tool_parameter.options = parameter.options
                tool_parameter.llm_description = parameter.llm_description
            else:
                # add new parameter
                parameters.append(parameter)

        for parameter in parameters:
            parameter_type = 'string'
            enum = []
            if parameter.type == ToolParamter.ToolParameterType.STRING:
                parameter_type = 'string'
            elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN:
                parameter_type = 'boolean'
            elif parameter.type == ToolParamter.ToolParameterType.NUMBER:
                parameter_type = 'number'
            elif parameter.type == ToolParamter.ToolParameterType.SELECT:
                for option in parameter.options:
                    enum.append(option.value)
                parameter_type = 'string'
            else:
                raise ValueError(f"parameter type {parameter.type} is not supported")

            if parameter.form == ToolParamter.ToolParameterForm.FORM:
                # get tool parameter from form
                tool_parameter_config = tool.tool_parameters.get(parameter.name)
                if not tool_parameter_config:
                    # get default value
                    tool_parameter_config = parameter.default
                    if not tool_parameter_config and parameter.required:
                        raise ValueError(f"tool parameter {parameter.name} not found in tool config")

                if parameter.type == ToolParamter.ToolParameterType.SELECT:
                    # check if tool_parameter_config in options
                    options = list(map(lambda x: x.value, parameter.options))
                    if tool_parameter_config not in options:
                        raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")

                # convert tool parameter config to correct type
                try:
                    if parameter.type == ToolParamter.ToolParameterType.NUMBER:
                        # int and float values can be used as-is; strings are parsed
                        if isinstance(tool_parameter_config, str):
                            if '.' in tool_parameter_config:
                                tool_parameter_config = float(tool_parameter_config)
                            else:
                                tool_parameter_config = int(tool_parameter_config)
                    elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN:
                        tool_parameter_config = bool(tool_parameter_config)
                    elif parameter.type not in [ToolParamter.ToolParameterType.SELECT, ToolParamter.ToolParameterType.STRING]:
                        tool_parameter_config = str(tool_parameter_config)
                except Exception:
                    raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not the correct type")

                # save tool parameter to tool entity memory
                runtime_parameters[parameter.name] = tool_parameter_config

            elif parameter.form == ToolParamter.ToolParameterForm.LLM:
                message_tool.parameters['properties'][parameter.name] = {
                    "type": parameter_type,
                    "description": parameter.llm_description or '',
                }

                if len(enum) > 0:
                    message_tool.parameters['properties'][parameter.name]['enum'] = enum

                if parameter.required:
                    message_tool.parameters['required'].append(parameter.name)

        tool_entity.runtime.runtime_parameters.update(runtime_parameters)

        return message_tool, tool_entity
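    # NOTE: illustrative example, assuming a hypothetical "weather" tool with
    # one required LLM-form SELECT parameter "city" — the loop above would
    # build a JSON-Schema-shaped dict like:
    #
    #     {
    #         "type": "object",
    #         "properties": {
    #             "city": {
    #                 "type": "string",
    #                 "description": "the city to query",
    #                 "enum": ["london", "tokyo"]
    #             }
    #         },
    #         "required": ["city"]
    #     }
    #
    # which is the shape function-calling models expect in
    # PromptMessageTool.parameters.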

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        convert dataset retriever tool to prompt message tool
        """
        prompt_tool = PromptMessageTool(
            name=tool.identity.name,
            description=tool.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = 'string'

            prompt_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters['required']:
                    prompt_tool.parameters['required'].append(parameter.name)

        return prompt_tool

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        update prompt message tool
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters() or []

        for parameter in tool_runtime_parameters:
            parameter_type = 'string'
            enum = []
            if parameter.type == ToolParamter.ToolParameterType.STRING:
                parameter_type = 'string'
            elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN:
                parameter_type = 'boolean'
            elif parameter.type == ToolParamter.ToolParameterType.NUMBER:
                parameter_type = 'number'
            elif parameter.type == ToolParamter.ToolParameterType.SELECT:
                for option in parameter.options:
                    enum.append(option.value)
                parameter_type = 'string'
            else:
                raise ValueError(f"parameter type {parameter.type} is not supported")

            if parameter.form == ToolParamter.ToolParameterForm.LLM:
                prompt_tool.parameters['properties'][parameter.name] = {
                    "type": parameter_type,
                    "description": parameter.llm_description or '',
                }

                if len(enum) > 0:
                    prompt_tool.parameters['properties'][parameter.name]['enum'] = enum

                if parameter.required:
                    if parameter.name not in prompt_tool.parameters['required']:
                        prompt_tool.parameters['required'].append(parameter.name)

        return prompt_tool

    def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
        """
        Extract tool response binary
        """
        result = []

        for response in tool_response:
            if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
                    response.type == ToolInvokeMessage.MessageType.IMAGE:
                result.append(ToolInvokeMessageBinary(
                    mimetype=response.meta.get('mime_type', 'octet/stream'),
                    url=response.message,
                    save_as=response.save_as,
                ))
            elif response.type == ToolInvokeMessage.MessageType.BLOB:
                result.append(ToolInvokeMessageBinary(
                    mimetype=response.meta.get('mime_type', 'octet/stream'),
                    url=response.message,
                    save_as=response.save_as,
                ))
            elif response.type == ToolInvokeMessage.MessageType.LINK:
                # only collect links that carry a mime type in meta
                if response.meta and 'mime_type' in response.meta:
                    result.append(ToolInvokeMessageBinary(
                        mimetype=response.meta.get('mime_type', 'octet/stream'),
                        url=response.message,
                        save_as=response.save_as,
                    ))

        return result

    def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
        """
        Create message file

        :param messages: messages
        :return: message files, should save as variable
        """
        result = []

        for message in messages:
            file_type = 'bin'
            if 'image' in message.mimetype:
                file_type = 'image'
            elif 'video' in message.mimetype:
                file_type = 'video'
            elif 'audio' in message.mimetype:
                file_type = 'audio'
            elif 'text' in message.mimetype:
                file_type = 'text'
            elif 'pdf' in message.mimetype:
                file_type = 'pdf'
            elif 'zip' in message.mimetype:
                file_type = 'archive'
            # ...

            invoke_from = self.application_generate_entity.invoke_from

            message_file = MessageFile(
                message_id=self.message.id,
                type=file_type,
                transfer_method=FileTransferMethod.TOOL_FILE.value,
                belongs_to='assistant',
                url=message.url,
                upload_file_id=None,
                created_by_role=('account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
                created_by=self.user_id,
            )
            db.session.add(message_file)
            result.append((
                message_file,
                message.save_as
            ))

        db.session.commit()

        return result

    def create_agent_thought(self, message_id: str, message: str,
                             tool_name: str, tool_input: str, messages_ids: List[str]
                             ) -> MessageAgentThought:
        """
        Create agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought='',
            tool=tool_name,
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else '',
            answer='',
            observation='',
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency='USD',
            latency=0,
            created_by_role='account',
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(self,
                           agent_thought: MessageAgentThought,
                           tool_name: str,
                           tool_input: Union[str, dict],
                           thought: str,
                           observation: str,
                           answer: str,
                           messages_ids: List[str],
                           llm_usage: LLMUsage = None) -> MessageAgentThought:
        """
        Save agent thought
        """
        if thought is not None:
            agent_thought.thought = thought

        if tool_name is not None:
            agent_thought.tool = tool_name

        if tool_input is not None:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception:
                    tool_input = json.dumps(tool_input)

            agent_thought.tool_input = tool_input

        if observation is not None:
            agent_thought.observation = observation

        if answer is not None:
            agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            agent_thought.message_token = llm_usage.prompt_tokens
            agent_thought.message_price_unit = llm_usage.prompt_price_unit
            agent_thought.message_unit_price = llm_usage.prompt_unit_price
            agent_thought.answer_token = llm_usage.completion_tokens
            agent_thought.answer_price_unit = llm_usage.completion_price_unit
            agent_thought.answer_unit_price = llm_usage.completion_unit_price
            agent_thought.tokens = llm_usage.total_tokens
            agent_thought.total_price = llm_usage.total_price

        db.session.commit()

    def get_history_prompt_messages(self) -> List[PromptMessage]:
        """
        Get history prompt messages
        """
        if self.history_prompt_messages is None:
            self.history_prompt_messages = db.session.query(PromptMessage).filter(
                PromptMessage.message_id == self.message.id,
            ).order_by(PromptMessage.position.asc()).all()

        return self.history_prompt_messages

    def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]:
        """
        Transform tool message into agent thought
        """
        result = []

        for message in messages:
            if message.type == ToolInvokeMessage.MessageType.TEXT:
                result.append(message)
            elif message.type == ToolInvokeMessage.MessageType.LINK:
                result.append(message)
            elif message.type == ToolInvokeMessage.MessageType.IMAGE:
                # try to download image
                try:
                    file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id,
                                                              conversation_id=self.message.conversation_id,
                                                              file_url=message.message)

                    url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'

                    result.append(ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                        message=url,
                        save_as=message.save_as,
                        meta=message.meta.copy() if message.meta is not None else {},
                    ))
                except Exception as e:
                    logger.exception(e)
                    result.append(ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.TEXT,
                        message=f"Failed to download image: {message.message}, you can try to download it yourself.",
                        meta=message.meta.copy() if message.meta is not None else {},
                        save_as=message.save_as,
                    ))
            elif message.type == ToolInvokeMessage.MessageType.BLOB:
                # get mime type and save blob to storage
                mimetype = message.meta.get('mime_type', 'octet/stream')
                # if message is str, encode it to bytes
                if isinstance(message.message, str):
                    message.message = message.message.encode('utf-8')
                file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id,
                                                          conversation_id=self.message.conversation_id,
                                                          file_binary=message.message,
                                                          mimetype=mimetype)

                url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'

                # check if file is image
                if 'image' in mimetype:
                    result.append(ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                        message=url,
                        save_as=message.save_as,
                        meta=message.meta.copy() if message.meta is not None else {},
                    ))
                else:
                    result.append(ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.LINK,
                        message=url,
                        save_as=message.save_as,
                        meta=message.meta.copy() if message.meta is not None else {},
                    ))
            else:
                result.append(message)

        return result

    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
        """
        convert tool variables to db variables
        """
        db_variables.updated_at = datetime.utcnow()
        db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
        db.session.commit()
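_convert_tool_response_to_str above is the glue that turns structured tool output into the plain-text observation the model reads. A standalone sketch of the same mapping (simplified stand-in types, not the dify entities):

    from enum import Enum

    class MsgType(Enum):  # hypothetical stand-in for ToolInvokeMessage.MessageType
        TEXT = 'text'
        LINK = 'link'
        IMAGE_LINK = 'image_link'

    def to_observation(messages: list[tuple[MsgType, str]]) -> str:
        # mirror the mapping used above: text passes through, links are
        # narrated, images are summarized for the model
        parts = []
        for msg_type, payload in messages:
            if msg_type == MsgType.TEXT:
                parts.append(payload)
            elif msg_type == MsgType.LINK:
                parts.append(f"result link: {payload}. please direct user to check it.")
            else:
                parts.append("image has been created and sent to user already.")
        return ''.join(parts)

    print(to_observation([(MsgType.TEXT, 'done. '), (MsgType.LINK, 'https://example.com/r/1')]))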
581 api/core/features/assistant_cot_runner.py Normal file
@@ -0,0 +1,581 @@
import json
import logging
import re
from typing import Dict, Generator, List, Literal, Union

from core.application_queue_manager import PublishFrom
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, \
    PromptMessageTool, SystemPromptMessage, UserPromptMessage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.errors import ToolInvokeError, ToolNotFoundError, ToolNotSupportedError, \
    ToolParamterValidationError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from models.model import Conversation, Message


class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
    def run(self, model_instance: ModelInstance,
            conversation: Conversation,
            message: Message,
            query: str,
            ) -> Union[Generator, LLMResult]:
        """
        Run Cot agent application
        """
        app_orchestration_config = self.app_orchestration_config
        self._repacket_app_orchestration_config(app_orchestration_config)

        agent_scratchpad: List[AgentScratchpadUnit] = []

        # check model mode
        if self.app_orchestration_config.model_config.mode == "completion":
            # TODO: stop words
            if 'Observation' not in app_orchestration_config.model_config.stop:
                app_orchestration_config.model_config.stop.append('Observation')

        iteration_step = 1
        max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1

        prompt_messages = self.history_prompt_messages

        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: List[PromptMessageTool] = []
        tool_instances = {}
        for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        function_call_state = True
        llm_usage = {
            'usage': None
        }
        final_answer = ''

        def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                llm_usage = final_llm_usage_dict['usage']
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price

        while function_call_state and iteration_step <= max_iteration_steps:
            # continue to run until there is not any tool call
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []

            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            if iteration_step > 1:
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # update prompt messages
            prompt_messages = self._organize_cot_prompt_messages(
                mode=app_orchestration_config.model_config.mode,
                prompt_messages=prompt_messages,
                tools=prompt_messages_tools,
                agent_scratchpad=agent_scratchpad,
                agent_prompt_message=app_orchestration_config.agent.prompt,
                instruction=app_orchestration_config.prompt_template.simple_prompt_template,
                input=query
            )

            # recalculate llm max tokens
            self.recale_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            llm_result: LLMResult = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                tools=[],
                stop=app_orchestration_config.model_config.stop,
                stream=False,
                user=self.user_id,
                callbacks=[],
            )

            # check llm result
            if not llm_result:
                raise ValueError("failed to invoke llm")

            # get scratchpad
            scratchpad = self._extract_response_scratchpad(llm_result.message.content)
            agent_scratchpad.append(scratchpad)

            # get llm usage
            if llm_result.usage:
                increase_usage(llm_usage, llm_result.usage)

            # publish agent thought if it's first iteration
            if iteration_step == 1:
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            self.save_agent_thought(agent_thought=agent_thought,
                                    tool_name=scratchpad.action.action_name if scratchpad.action else '',
                                    tool_input=scratchpad.action.action_input if scratchpad.action else '',
                                    thought=scratchpad.thought,
                                    observation='',
                                    answer=llm_result.message.content,
                                    messages_ids=[],
                                    llm_usage=llm_result.usage)

            if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # publish agent thought if it's not empty and there is an action
            if scratchpad.thought and scratchpad.action:
                # check if final answer
                if not scratchpad.action.action_name.lower() == "final answer":
                    yield LLMResultChunk(
                        model=model_instance.model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=0,
                            message=AssistantPromptMessage(
                                content=scratchpad.thought
                            ),
                            usage=llm_result.usage,
                        ),
                        system_fingerprint=''
                    )

            if not scratchpad.action:
                # failed to extract action, return final answer directly
                final_answer = scratchpad.agent_response or ''
            else:
                if scratchpad.action.action_name.lower() == "final answer":
                    # action is final answer, return final answer directly
                    try:
                        final_answer = scratchpad.action.action_input if \
                            isinstance(scratchpad.action.action_input, str) else \
                            json.dumps(scratchpad.action.action_input)
                    except json.JSONDecodeError:
                        final_answer = f'{scratchpad.action.action_input}'
                else:
                    function_call_state = True

                    # action is tool call, invoke tool
                    tool_call_name = scratchpad.action.action_name
                    tool_call_args = scratchpad.action.action_input
                    tool_instance = tool_instances.get(tool_call_name)
                    if not tool_instance:
                        answer = f"there is not a tool named {tool_call_name}"
                        self.save_agent_thought(agent_thought=agent_thought,
                                                tool_name='',
                                                tool_input='',
                                                thought=None,
                                                observation=answer,
                                                answer=answer,
                                                messages_ids=[])
                        self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
                    else:
                        # invoke tool
                        error_response = None
                        try:
                            tool_response = tool_instance.invoke(
                                user_id=self.user_id,
                                tool_paramters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
                            )
                            # transform tool response to llm friendly response
                            tool_response = self.transform_tool_invoke_messages(tool_response)
                            # extract binary data from tool invoke message
                            binary_files = self.extract_tool_response_binary(tool_response)
                            # create message file
                            message_files = self.create_message_files(binary_files)
                            # publish files
                            for message_file, save_as in message_files:
                                if save_as:
                                    self.variables_pool.set_file(tool_name=tool_call_name,
                                                                 value=message_file.id,
                                                                 name=save_as)
                                self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)

                            message_file_ids = [message_file.id for message_file, _ in message_files]
                        except ToolProviderCredentialValidationError:
                            error_response = "Please check your tool provider credentials"
                        except (
                            ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
                        ):
                            error_response = f"there is not a tool named {tool_call_name}"
                        except (
                            ToolParamterValidationError
                        ) as e:
                            error_response = f"tool parameters validation error: {e}, please check your tool parameters"
                        except ToolInvokeError as e:
                            error_response = f"tool invoke error: {e}"
                        except Exception as e:
                            error_response = f"unknown error: {e}"

                        if error_response:
                            observation = error_response
                        else:
                            observation = self._convert_tool_response_to_str(tool_response)

                        # save scratchpad
                        scratchpad.observation = observation
                        scratchpad.agent_response = llm_result.message.content

                        # save agent thought
                        self.save_agent_thought(
                            agent_thought=agent_thought,
                            tool_name=tool_call_name,
                            tool_input=tool_call_args,
                            thought=None,
                            observation=observation,
                            answer=llm_result.message.content,
                            messages_ids=message_file_ids,
                        )
                        self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # update prompt tool message
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        yield LLMResultChunk(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(
                    content=final_answer
                ),
                usage=llm_usage['usage']
            ),
            system_fingerprint=''
        )

        # save agent thought
        self.save_agent_thought(
            agent_thought=agent_thought,
            tool_name='',
            tool_input='',
            thought=final_answer,
            observation='',
            answer=final_answer,
            messages_ids=[]
        )

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish_message_end(LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer
            ),
            usage=llm_usage['usage'],
            system_fingerprint=''
        ), PublishFrom.APPLICATION_MANAGER)

    def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
        """
        extract response from llm response
        """
        def extract_quotes() -> AgentScratchpadUnit:
            agent_response = content
            # try to extract all quotes
            pattern = re.compile(r'```(.*?)```', re.DOTALL)
            quotes = pattern.findall(content)

            # try to extract action from end to start
            for i in range(len(quotes) - 1, -1, -1):
                """
                1. use json load to parse action
                2. use plain text `Action: xxx` to parse action
                """
                try:
                    action = json.loads(quotes[i].replace('```', ''))
                    action_name = action.get("action")
                    action_input = action.get("action_input")
                    agent_thought = agent_response.replace(quotes[i], '')

                    if action_name and action_input:
                        return AgentScratchpadUnit(
                            agent_response=content,
                            thought=agent_thought,
                            action_str=quotes[i],
                            action=AgentScratchpadUnit.Action(
                                action_name=action_name,
                                action_input=action_input,
                            )
                        )
                except Exception:
                    # try to parse action from plain text
                    action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
                    action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
                    # delete action from agent response
                    agent_thought = agent_response.replace(quotes[i], '')
                    # remove extra quotes
                    agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
                    # remove Action: xxx from agent thought
                    agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)

                    if action_name and action_input:
                        return AgentScratchpadUnit(
                            agent_response=content,
                            thought=agent_thought,
                            action_str=quotes[i],
                            action=AgentScratchpadUnit.Action(
                                action_name=action_name[0],
                                action_input=action_input[0],
                            )
                        )

        def extract_json():
            agent_response = content
            # try to extract all json structures
            structures, pair_match_stack = [], []
            started_at, end_at = 0, 0
            for i in range(len(content)):
                if content[i] == '{':
                    pair_match_stack.append(i)
                    if len(pair_match_stack) == 1:
                        started_at = i
                elif content[i] == '}':
                    begin = pair_match_stack.pop()
                    if not pair_match_stack:
                        end_at = i + 1
                        structures.append((content[begin:i+1], (started_at, end_at)))

            # handle the last character
            if pair_match_stack:
                end_at = len(content)
                structures.append((content[pair_match_stack[0]:], (started_at, end_at)))

            for i in range(len(structures), 0, -1):
                try:
                    json_content, (started_at, end_at) = structures[i - 1]
                    action = json.loads(json_content)
                    action_name = action.get("action")
                    action_input = action.get("action_input")
                    # delete json content from agent response
                    agent_thought = agent_response[:started_at] + agent_response[end_at:]
                    # remove extra quotes like ```(json)*\n\n```
                    agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
                    # remove Action: xxx from agent thought
                    agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)

                    if action_name and action_input:
                        return AgentScratchpadUnit(
                            agent_response=content,
                            thought=agent_thought,
                            action_str=json_content,
                            action=AgentScratchpadUnit.Action(
                                action_name=action_name,
                                action_input=action_input,
                            )
                        )
                except Exception:
                    pass

        agent_scratchpad = extract_quotes()
        if agent_scratchpad:
            return agent_scratchpad
        agent_scratchpad = extract_json()
        if agent_scratchpad:
            return agent_scratchpad

        return AgentScratchpadUnit(
            agent_response=content,
            thought=content,
            action_str='',
            action=None
        )

    def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                   agent_prompt_message: AgentPromptEntity,
                                   ):
        """
        check chain of thought prompt messages, a standard prompt message is like:
            Respond to the human as helpfully and accurately as possible.

            {{instruction}}

            You have access to the following tools:

            {{tools}}

            Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
            Valid action values: "Final Answer" or {{tool_names}}

            Provide only ONE action per $JSON_BLOB, as shown:

            ```
            {
                "action": $TOOL_NAME,
                "action_input": $ACTION_INPUT
            }
            ```
        """

        # parse agent prompt message
        first_prompt = agent_prompt_message.first_prompt
        next_iteration = agent_prompt_message.next_iteration

        if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
            raise ValueError("first_prompt or next_iteration is required in CoT agent mode")

        # check instruction, tools, and tool_names slots
        if first_prompt.find("{{instruction}}") < 0:
            raise ValueError("{{instruction}} is required in first_prompt")
        if first_prompt.find("{{tools}}") < 0:
            raise ValueError("{{tools}} is required in first_prompt")
        if first_prompt.find("{{tool_names}}") < 0:
            raise ValueError("{{tool_names}} is required in first_prompt")

        if mode == "completion":
            if first_prompt.find("{{query}}") < 0:
                raise ValueError("{{query}} is required in first_prompt")
            if first_prompt.find("{{agent_scratchpad}}") < 0:
                raise ValueError("{{agent_scratchpad}} is required in first_prompt")

        if mode == "completion":
            if next_iteration.find("{{observation}}") < 0:
                raise ValueError("{{observation}} is required in next_iteration")

    def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
        """
        convert agent scratchpad list to str
        """
        next_iteration = self.app_orchestration_config.agent.prompt.next_iteration

        result = ''
        for scratchpad in agent_scratchpad:
            result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"

        return result

    def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                      prompt_messages: List[PromptMessage],
                                      tools: List[PromptMessageTool],
                                      agent_scratchpad: List[AgentScratchpadUnit],
                                      agent_prompt_message: AgentPromptEntity,
                                      instruction: str,
                                      input: str,
                                      ) -> List[PromptMessage]:
        """
        organize chain of thought prompt messages; see _check_cot_prompt_messages
        for the expected prompt template.
        """

        self._check_cot_prompt_messages(mode, agent_prompt_message)

        # parse agent prompt message
        first_prompt = agent_prompt_message.first_prompt

        # parse tools
        tools_str = self._jsonify_tool_prompt_messages(tools)

        # parse tools name
        tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'

        # get system message
        system_message = first_prompt.replace("{{instruction}}", instruction) \
            .replace("{{tools}}", tools_str) \
            .replace("{{tool_names}}", tool_names)

        # organize prompt messages
        if mode == "chat":
            # override system message
            overridden = False
            prompt_messages = prompt_messages.copy()
            for prompt_message in prompt_messages:
                if isinstance(prompt_message, SystemPromptMessage):
                    prompt_message.content = system_message
                    overridden = True
                    break

            if not overridden:
                prompt_messages.insert(0, SystemPromptMessage(
                    content=system_message,
                ))

            # add assistant message
            if len(agent_scratchpad) > 0:
                prompt_messages.append(AssistantPromptMessage(
                    content=(agent_scratchpad[-1].thought or '')
                ))

            # add user message
            if len(agent_scratchpad) > 0:
                prompt_messages.append(UserPromptMessage(
                    content=(agent_scratchpad[-1].observation or ''),
                ))

            return prompt_messages
        elif mode == "completion":
            # parse agent scratchpad
            agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
            # parse prompt messages
            return [UserPromptMessage(
                content=first_prompt.replace("{{instruction}}", instruction)
                    .replace("{{tools}}", tools_str)
                    .replace("{{tool_names}}", tool_names)
                    .replace("{{query}}", input)
                    .replace("{{agent_scratchpad}}", agent_scratchpad_str),
            )]
        else:
            raise ValueError(f"mode {mode} is not supported")

    def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
        """
        jsonify tool prompt messages
        """
        tools = jsonable_encoder(tools)
        try:
            return json.dumps(tools, ensure_ascii=False)
        except json.JSONDecodeError:
            return json.dumps(tools)
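The CoT runner's _extract_response_scratchpad parses a ReAct-style completion: the last fenced block wins, JSON is tried first, and plain-text `Action:` lines are the fallback. A standalone sketch of the happy path on a sample response (plain json/re, hypothetical sample text — not the dify classes):

    import json
    import re

    # a typical ReAct-style completion the CoT runner has to parse
    sample = (
        "I should look this up.\n"
        "```json\n"
        '{"action": "google_search", "action_input": {"query": "dify agent"}}\n'
        "```"
    )

    # mirror extract_quotes: take the last fenced block, strip the language
    # tag, and json-load it
    quotes = re.findall(r'```(.*?)```', sample, re.DOTALL)
    action = json.loads(quotes[-1].replace('json', '', 1))
    thought = sample.replace(quotes[-1], '')

    print(action["action"])        # -> google_search
    print(action["action_input"])  # -> {'query': 'dify agent'}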
333 api/core/features/assistant_fc_runner.py Normal file
@@ -0,0 +1,333 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import Union, Generator, Dict, Any, Tuple, List
|
||||
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
|
||||
SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
|
||||
from core.model_manager import ModelInstance
|
||||
from core.application_queue_manager import PublishFrom
|
||||
|
||||
from core.tools.errors import ToolInvokeError, ToolNotFoundError, \
|
||||
ToolNotSupportedError, ToolProviderNotFoundError, ToolParamterValidationError, \
|
||||
ToolProviderCredentialValidationError
|
||||
|
||||
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
|
||||
|
||||
from models.model import Conversation, Message, MessageAgentThought
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
def run(self, model_instance: ModelInstance,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
query: str,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Run FunctionCall agent application
|
||||
"""
|
||||
app_orchestration_config = self.app_orchestration_config
|
||||
|
||||
prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or ''
|
||||
prompt_messages = self.history_prompt_messages
|
||||
prompt_messages = self.organize_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
query=query,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
prompt_messages_tools: List[PromptMessageTool] = []
|
||||
tool_instances = {}
|
||||
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
|
||||
try:
|
||||
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
|
||||
except Exception:
|
||||
# api tool may be deleted
|
||||
continue
|
||||
# save tool entity
|
||||
tool_instances[tool.tool_name] = tool_entity
|
||||
# save prompt tool
|
||||
prompt_messages_tools.append(prompt_tool)
|
||||
|
||||
# convert dataset tools into ModelRuntime Tool format
|
||||
for dataset_tool in self.dataset_tools:
|
||||
prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
|
||||
# save prompt tool
|
||||
prompt_messages_tools.append(prompt_tool)
|
||||
# save tool entity
|
||||
tool_instances[dataset_tool.identity.name] = dataset_tool
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
agent_thoughts: List[MessageAgentThought] = []
|
||||
llm_usage = {
|
||||
'usage': None
|
||||
}
|
||||
final_answer = ''
|
||||
|
||||
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
|
||||
if not final_llm_usage_dict['usage']:
|
||||
final_llm_usage_dict['usage'] = usage
|
||||
else:
|
||||
llm_usage = final_llm_usage_dict['usage']
|
||||
llm_usage.prompt_tokens += usage.prompt_tokens
|
||||
llm_usage.completion_tokens += usage.completion_tokens
|
||||
llm_usage.prompt_price += usage.prompt_price
|
||||
llm_usage.completion_price += usage.completion_price
|
||||
|
||||
while function_call_state and iteration_step <= max_iteration_steps:
    function_call_state = False

    if iteration_step == max_iteration_steps:
        # the last iteration, remove all tools
        prompt_messages_tools = []

    message_file_ids = []
    agent_thought = self.create_agent_thought(
        message_id=message.id,
        message='',
        tool_name='',
        tool_input='',
        messages_ids=message_file_ids
    )
    self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

    # recale llm max tokens
    self.recale_llm_max_tokens(self.model_config, prompt_messages)
    # invoke model
    chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
        prompt_messages=prompt_messages,
        model_parameters=app_orchestration_config.model_config.parameters,
        tools=prompt_messages_tools,
        stop=app_orchestration_config.model_config.stop,
        stream=True,
        user=self.user_id,
        callbacks=[],
    )

    tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []

    # save full response
    response = ''

    # save tool call names and inputs
    tool_call_names = ''
    tool_call_inputs = ''

    current_llm_usage = None

    for chunk in chunks:
        # check if there is any tool call
        if self.check_tool_calls(chunk):
            function_call_state = True
            tool_calls.extend(self.extract_tool_calls(chunk))
            tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
            try:
                tool_call_inputs = json.dumps({
                    tool_call[1]: tool_call[2] for tool_call in tool_calls
                }, ensure_ascii=False)
            except json.JSONDecodeError as e:
                # ensure ascii to avoid encoding error
                tool_call_inputs = json.dumps({
                    tool_call[1]: tool_call[2] for tool_call in tool_calls
                })

        if chunk.delta.message and chunk.delta.message.content:
            if isinstance(chunk.delta.message.content, list):
                for content in chunk.delta.message.content:
                    response += content.data
            else:
                response += chunk.delta.message.content

        if chunk.delta.usage:
            increase_usage(llm_usage, chunk.delta.usage)
            current_llm_usage = chunk.delta.usage

        yield chunk

    # save thought
    self.save_agent_thought(
        agent_thought=agent_thought,
        tool_name=tool_call_names,
        tool_input=tool_call_inputs,
        thought=response,
        observation=None,
        answer=response,
        messages_ids=[],
        llm_usage=current_llm_usage
    )

    self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

    final_answer += response + '\n'
    # call tools
    tool_responses = []
    for tool_call_id, tool_call_name, tool_call_args in tool_calls:
        tool_instance = tool_instances.get(tool_call_name)
        if not tool_instance:
            tool_response = {
                "tool_call_id": tool_call_id,
                "tool_call_name": tool_call_name,
                "tool_response": f"there is not a tool named {tool_call_name}"
            }
            tool_responses.append(tool_response)
        else:
            # invoke tool
            error_response = None
            try:
                tool_invoke_message = tool_instance.invoke(
                    user_id=self.user_id,
                    tool_paramters=tool_call_args,
                )
                # transform tool invoke message to get LLM friendly message
                tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message)
                # extract binary data from tool invoke message
                binary_files = self.extract_tool_response_binary(tool_invoke_message)
                # create message file
                message_files = self.create_message_files(binary_files)
                # publish files
                for message_file, save_as in message_files:
                    if save_as:
                        self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)

                    # publish message file
                    self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
                    # add message file ids
                    message_file_ids.append(message_file.id)

            except ToolProviderCredentialValidationError as e:
                error_response = f"Please check your tool provider credentials"
            except (
                ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
            ) as e:
                error_response = f"there is not a tool named {tool_call_name}"
            except (
                ToolParamterValidationError
            ) as e:
                error_response = f"tool parameters validation error: {e}, please check your tool parameters"
            except ToolInvokeError as e:
                error_response = f"tool invoke error: {e}"
            except Exception as e:
                error_response = f"unknown error: {e}"

            if error_response:
                observation = error_response
                tool_response = {
                    "tool_call_id": tool_call_id,
                    "tool_call_name": tool_call_name,
                    "tool_response": error_response
                }
                tool_responses.append(tool_response)
            else:
                observation = self._convert_tool_response_to_str(tool_invoke_message)
                tool_response = {
                    "tool_call_id": tool_call_id,
                    "tool_call_name": tool_call_name,
                    "tool_response": observation
                }
                tool_responses.append(tool_response)

        prompt_messages = self.organize_prompt_messages(
            prompt_template=prompt_template,
            query=None,
            tool_call_id=tool_call_id,
            tool_call_name=tool_call_name,
            tool_response=tool_response['tool_response'],
            prompt_messages=prompt_messages,
        )
    if len(tool_responses) > 0:
        # save agent thought
        self.save_agent_thought(
            agent_thought=agent_thought,
            tool_name=None,
            tool_input=None,
            thought=None,
            observation=tool_response['tool_response'],
            answer=None,
            messages_ids=message_file_ids
        )
        self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

    # update prompt messages
    if response.strip():
        prompt_messages.append(AssistantPromptMessage(
            content=response,
        ))

    # update prompt tool
    for prompt_tool in prompt_messages_tools:
        self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

    iteration_step += 1

self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish_message_end(LLMResult(
    model=model_instance.model,
    prompt_messages=prompt_messages,
    message=AssistantPromptMessage(
        content=final_answer,
    ),
    usage=llm_usage['usage'],
    system_fingerprint=''
), PublishFrom.APPLICATION_MANAGER)
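A note on the loop above: `max_iteration_steps` is capped at `min(max_iteration, 5) + 1`, and on the final pass `prompt_messages_tools` is emptied, so the model is forced to produce a plain answer rather than yet another tool call. A minimal, self-contained sketch of that termination logic (the model and tools are stubbed; only the bound and the tool-stripping mirror the code above):

```python
# Sketch of the agent loop's termination behavior; llm_wants_tools fakes
# "how many rounds the model would keep calling tools" and is an assumption.
def run_agent(max_iteration: int, llm_wants_tools: int) -> int:
    max_iteration_steps = min(max_iteration, 5) + 1
    iteration_step = 1
    rounds = 0
    function_call_state = True
    while function_call_state and iteration_step <= max_iteration_steps:
        function_call_state = False
        tools_available = iteration_step != max_iteration_steps  # last pass: tools removed
        if tools_available and rounds < llm_wants_tools:
            function_call_state = True  # pretend the model emitted a tool call
        rounds += 1
        iteration_step += 1
    return rounds

assert run_agent(max_iteration=3, llm_wants_tools=99) == 4  # 3 tool rounds + 1 forced answer
```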
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
    """
    Check if there is any tool call in llm result chunk
    """
    if llm_result_chunk.delta.message.tool_calls:
        return True
    return False

def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
    """
    Extract tool calls from llm result chunk

    Returns:
        List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
    """
    tool_calls = []
    for prompt_message in llm_result_chunk.delta.message.tool_calls:
        tool_calls.append((
            prompt_message.id,
            prompt_message.function.name,
            json.loads(prompt_message.function.arguments),
        ))

    return tool_calls

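To make the chunk shape these two helpers consume concrete, here is an illustration with stand-in objects; the real inputs are Dify's `LLMResultChunk` entities, so everything below the imports is an assumption built only for demonstration:

```python
# Illustration only: SimpleNamespace stand-ins approximating what
# check_tool_calls/extract_tool_calls read off a streamed chunk.
import json
from types import SimpleNamespace

tool_call = SimpleNamespace(
    id='call_0',
    function=SimpleNamespace(name='weather', arguments=json.dumps({'city': 'Berlin'})),
)
chunk = SimpleNamespace(delta=SimpleNamespace(message=SimpleNamespace(tool_calls=[tool_call])))

# mirrors extract_tool_calls: (id, name, parsed args) per call
extracted = [
    (c.id, c.function.name, json.loads(c.function.arguments))
    for c in chunk.delta.message.tool_calls
]
assert extracted == [('call_0', 'weather', {'city': 'Berlin'})]
```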
def organize_prompt_messages(self, prompt_template: str,
                             query: str = None,
                             tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
                             prompt_messages: list[PromptMessage] = None
                             ) -> list[PromptMessage]:
    """
    Organize prompt messages
    """

    if not prompt_messages:
        prompt_messages = [
            SystemPromptMessage(content=prompt_template),
            UserPromptMessage(content=query),
        ]
    else:
        if tool_response:
            prompt_messages = prompt_messages.copy()
            prompt_messages.append(
                ToolPromptMessage(
                    content=tool_response,
                    tool_call_id=tool_call_id,
                    name=tool_call_name,
                )
            )

    return prompt_messages
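Read together with the loop, `organize_prompt_messages` grows the conversation one tool round at a time: the first call seeds a `SystemPromptMessage` plus `UserPromptMessage`, each later call appends a `ToolPromptMessage` with the observation, and the loop itself appends the assistant text (only when `response.strip()` is non-empty). A rough trace, with tuples standing in for Dify's PromptMessage classes:

```python
# Illustrative trace of how the prompt list grows across one tool round;
# tuples are stand-ins for the real PromptMessage classes.
prompt_messages = [
    ('system', 'prompt_template ...'),
    ('user', "What's the weather in Berlin?"),
]
# round 1: the model emits a tool call; after invoking the tool, the loop appends
prompt_messages.append(('tool', '12C and cloudy', 'call_0', 'weather'))  # content, tool_call_id, name
# (the assistant's streamed text is appended too, but only if it was non-empty)
# round 2: the model now sees the observation and can answer in plain text
```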
@@ -6,8 +6,8 @@ from core.entities.application_entities import DatasetEntity, DatasetRetrieveCon
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
-from core.tool.dataset_retriever_tool import DatasetRetrieverTool
+from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
+from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 from extensions.ext_database import db
 from langchain.tools import BaseTool
 from models.dataset import Dataset

@@ -166,8 +166,7 @@ class DatasetRetrievalFeature:
             dataset_ids=[dataset.id for dataset in available_datasets],
             tenant_id=tenant_id,
             top_k=retrieve_config.top_k or 2,
-            score_threshold=(retrieve_config.score_threshold or 0.5)
-            if retrieve_config.reranking_model.get('score_threshold_enabled', False) else None,
+            score_threshold=retrieve_config.score_threshold,
             hit_callbacks=[hit_callback],
             return_resource=return_resource,
             retriever_from=invoke_from.to_source(),

@@ -22,6 +22,7 @@ class FileType(enum.Enum):
 class FileTransferMethod(enum.Enum):
     REMOTE_URL = 'remote_url'
     LOCAL_FILE = 'local_file'
+    TOOL_FILE = 'tool_file'
 
     @staticmethod
     def value_of(value):
@@ -30,6 +31,16 @@ class FileTransferMethod(enum.Enum):
             return member
         raise ValueError(f"No matching enum found for value '{value}'")
 
+class FileBelongsTo(enum.Enum):
+    USER = 'user'
+    ASSISTANT = 'assistant'
+
+    @staticmethod
+    def value_of(value):
+        for member in FileBelongsTo:
+            if member.value == value:
+                return member
+        raise ValueError(f"No matching enum found for value '{value}'")
 
 class FileObj(BaseModel):
     id: Optional[str]

@@ -1,8 +1,8 @@
 from typing import Dict, List, Optional, Union
 
 import requests
-from core.file.file_obj import FileObj, FileTransferMethod, FileType
-from core.file.upload_file_parser import SUPPORT_EXTENSIONS
+from core.file.file_obj import FileObj, FileTransferMethod, FileType, FileBelongsTo
+from services.file_service import IMAGE_EXTENSIONS
 from extensions.ext_database import db
 from models.account import Account
 from models.model import AppModelConfig, EndUser, MessageFile, UploadFile
@@ -83,7 +83,7 @@ class MessageFileParser:
                 UploadFile.tenant_id == self.tenant_id,
                 UploadFile.created_by == user.id,
                 UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
-                UploadFile.extension.in_(SUPPORT_EXTENSIONS)
+                UploadFile.extension.in_(IMAGE_EXTENSIONS)
             ).first())
 
             # check upload file is belong to tenant and user
@@ -128,6 +128,10 @@ class MessageFileParser:
 
         # group by file type and convert file args or message files to FileObj
         for file in files:
+            if isinstance(file, MessageFile):
+                if file.belongs_to == FileBelongsTo.ASSISTANT.value:
+                    continue
+
             file_obj = self._to_file_obj(file, file_upload_config)
             if file_obj.type not in type_file_objs:
                 continue

api/core/file/tool_file_parser.py (new file, 8 lines)
@@ -0,0 +1,8 @@
+tool_file_manager = {
+    'manager': None
+}
+
+class ToolFileParser:
+    @staticmethod
+    def get_tool_file_manager() -> 'ToolFileManager':
+        return tool_file_manager['manager']
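The one-element dict in this new file is a small late-binding registry: it lets modules import `ToolFileParser` before a concrete `ToolFileManager` exists (typically to dodge a circular import), with the real instance injected at startup. A sketch of the pattern under that assumption — `set_tool_file_manager` and `DummyManager` are illustrative, not part of the diff:

```python
# Hypothetical wiring for the registry above; only the dict-lookup pattern
# comes from the diff, the setter and DummyManager are assumptions.
tool_file_manager = {'manager': None}

class ToolFileParser:
    @staticmethod
    def get_tool_file_manager():
        return tool_file_manager['manager']

def set_tool_file_manager(manager) -> None:
    tool_file_manager['manager'] = manager  # called once during app startup

class DummyManager:
    pass

set_tool_file_manager(DummyManager())
assert isinstance(ToolFileParser.get_tool_file_manager(), DummyManager)
```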
@@ -9,8 +9,8 @@ from typing import Optional
 from extensions.ext_storage import storage
 from flask import current_app
 
-SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
+IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
+IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
 
 class UploadFileParser:
     @classmethod
@@ -18,7 +18,7 @@ class UploadFileParser:
         if not upload_file:
             return None
 
-        if upload_file.extension not in SUPPORT_EXTENSIONS:
+        if upload_file.extension not in IMAGE_EXTENSIONS:
             return None
 
         if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url:

@@ -1,9 +1,8 @@
-import os
 from typing import Optional
 
 from core.entities.provider_entities import QuotaUnit, RestrictModel
 from core.model_runtime.entities.model_entities import ModelType
-from flask import Flask
+from flask import Flask, Config
 from models.provider import ProviderQuotaType
 from pydantic import BaseModel
 
@@ -48,46 +47,47 @@ class HostingConfiguration:
     moderation_config: HostedModerationConfig = None
 
     def init_app(self, app: Flask) -> None:
-        if app.config.get('EDITION') != 'CLOUD':
+        config = app.config
+
+        if config.get('EDITION') != 'CLOUD':
             return
 
-        self.provider_map["azure_openai"] = self.init_azure_openai()
-        self.provider_map["openai"] = self.init_openai()
-        self.provider_map["anthropic"] = self.init_anthropic()
-        self.provider_map["minimax"] = self.init_minimax()
-        self.provider_map["spark"] = self.init_spark()
-        self.provider_map["zhipuai"] = self.init_zhipuai()
+        self.provider_map["azure_openai"] = self.init_azure_openai(config)
+        self.provider_map["openai"] = self.init_openai(config)
+        self.provider_map["anthropic"] = self.init_anthropic(config)
+        self.provider_map["minimax"] = self.init_minimax(config)
+        self.provider_map["spark"] = self.init_spark(config)
+        self.provider_map["zhipuai"] = self.init_zhipuai(config)
 
-        self.moderation_config = self.init_moderation_config()
+        self.moderation_config = self.init_moderation_config(config)
 
-    def init_azure_openai(self) -> HostingProvider:
+    def init_azure_openai(self, app_config: Config) -> HostingProvider:
         quota_unit = QuotaUnit.TIMES
-        if os.environ.get("HOSTED_AZURE_OPENAI_ENABLED") and os.environ.get("HOSTED_AZURE_OPENAI_ENABLED").lower() == 'true':
+        if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"):
             credentials = {
-                "openai_api_key": os.environ.get("HOSTED_AZURE_OPENAI_API_KEY"),
-                "openai_api_base": os.environ.get("HOSTED_AZURE_OPENAI_API_BASE"),
+                "openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"),
+                "openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"),
                 "base_model_name": "gpt-35-turbo"
             }
 
             quotas = []
-            hosted_quota_limit = int(os.environ.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
-            if hosted_quota_limit != -1 or hosted_quota_limit > 0:
-                trial_quota = TrialHostingQuota(
-                    quota_limit=hosted_quota_limit,
-                    restrict_models=[
-                        RestrictModel(model="gpt-4", base_model_name="gpt-4", model_type=ModelType.LLM),
-                        RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM),
-                        RestrictModel(model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM),
-                        RestrictModel(model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM),
-                        RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM),
-                        RestrictModel(model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM),
-                        RestrictModel(model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM),
-                        RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM),
-                        RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM),
-                        RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING),
-                    ]
-                )
-                quotas.append(trial_quota)
+            hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
+            trial_quota = TrialHostingQuota(
+                quota_limit=hosted_quota_limit,
+                restrict_models=[
+                    RestrictModel(model="gpt-4", base_model_name="gpt-4", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM),
+                    RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM),
+                    RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING),
+                ]
+            )
+            quotas.append(trial_quota)
 
             return HostingProvider(
                 enabled=True,
@@ -101,43 +101,44 @@ class HostingConfiguration:
                 quota_unit=quota_unit,
             )
 
-    def init_openai(self) -> HostingProvider:
+    def init_openai(self, app_config: Config) -> HostingProvider:
         quota_unit = QuotaUnit.TIMES
-        if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
+        quotas = []
+
+        if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
+            hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
+            trial_quota = TrialHostingQuota(
+                quota_limit=hosted_quota_limit,
+                restrict_models=[
+                    RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
+                    RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
+                    RestrictModel(model="whisper-1", model_type=ModelType.SPEECH2TEXT),
+                ]
+            )
+            quotas.append(trial_quota)
+
+        if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
+            paid_quota = PaidHostingQuota(
+                stripe_price_id=app_config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
+                increase_quota=int(app_config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA", "1")),
+                min_quantity=int(app_config.get("HOSTED_OPENAI_PAID_MIN_QUANTITY", "1")),
+                max_quantity=int(app_config.get("HOSTED_OPENAI_PAID_MAX_QUANTITY", "1"))
+            )
+            quotas.append(paid_quota)
+
+        if len(quotas) > 0:
             credentials = {
-                "openai_api_key": os.environ.get("HOSTED_OPENAI_API_KEY"),
+                "openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"),
             }
 
-            if os.environ.get("HOSTED_OPENAI_API_BASE"):
-                credentials["openai_api_base"] = os.environ.get("HOSTED_OPENAI_API_BASE")
+            if app_config.get("HOSTED_OPENAI_API_BASE"):
+                credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE")
 
-            if os.environ.get("HOSTED_OPENAI_API_ORGANIZATION"):
-                credentials["openai_organization"] = os.environ.get("HOSTED_OPENAI_API_ORGANIZATION")
-
-            quotas = []
-            hosted_quota_limit = int(os.environ.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
-            if hosted_quota_limit != -1 or hosted_quota_limit > 0:
-                trial_quota = TrialHostingQuota(
-                    quota_limit=hosted_quota_limit,
-                    restrict_models=[
-                        RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
-                        RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
-                        RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
-                        RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
-                        RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
-                    ]
-                )
-                quotas.append(trial_quota)
-
-            if os.environ.get("HOSTED_OPENAI_PAID_ENABLED") and os.environ.get(
-                    "HOSTED_OPENAI_PAID_ENABLED").lower() == 'true':
-                paid_quota = PaidHostingQuota(
-                    stripe_price_id=os.environ.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
-                    increase_quota=int(os.environ.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA", "1")),
-                    min_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MIN_QUANTITY", "1")),
-                    max_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MAX_QUANTITY", "1"))
-                )
-                quotas.append(paid_quota)
+            if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"):
+                credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION")
 
             return HostingProvider(
                 enabled=True,
@@ -151,33 +152,33 @@ class HostingConfiguration:
                 quota_unit=quota_unit,
             )
 
-    def init_anthropic(self) -> HostingProvider:
+    def init_anthropic(self, app_config: Config) -> HostingProvider:
         quota_unit = QuotaUnit.TOKENS
-        if os.environ.get("HOSTED_ANTHROPIC_ENABLED") and os.environ.get("HOSTED_ANTHROPIC_ENABLED").lower() == 'true':
+        quotas = []
+
+        if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"):
+            hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
+            trial_quota = TrialHostingQuota(
+                quota_limit=hosted_quota_limit
+            )
+            quotas.append(trial_quota)
+
+        if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"):
+            paid_quota = PaidHostingQuota(
+                stripe_price_id=app_config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
+                increase_quota=int(app_config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA", "1000000")),
+                min_quantity=int(app_config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY", "20")),
+                max_quantity=int(app_config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY", "100"))
+            )
+            quotas.append(paid_quota)
+
+        if len(quotas) > 0:
             credentials = {
-                "anthropic_api_key": os.environ.get("HOSTED_ANTHROPIC_API_KEY"),
+                "anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"),
             }
 
-            if os.environ.get("HOSTED_ANTHROPIC_API_BASE"):
-                credentials["anthropic_api_url"] = os.environ.get("HOSTED_ANTHROPIC_API_BASE")
-
-            quotas = []
-            hosted_quota_limit = int(os.environ.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
-            if hosted_quota_limit != -1 or hosted_quota_limit > 0:
-                trial_quota = TrialHostingQuota(
-                    quota_limit=hosted_quota_limit
-                )
-                quotas.append(trial_quota)
-
-            if os.environ.get("HOSTED_ANTHROPIC_PAID_ENABLED") and os.environ.get(
-                    "HOSTED_ANTHROPIC_PAID_ENABLED").lower() == 'true':
-                paid_quota = PaidHostingQuota(
-                    stripe_price_id=os.environ.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
-                    increase_quota=int(os.environ.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA", "1000000")),
-                    min_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY", "20")),
-                    max_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY", "100"))
-                )
-                quotas.append(paid_quota)
+            if app_config.get("HOSTED_ANTHROPIC_API_BASE"):
+                credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE")
 
             return HostingProvider(
                 enabled=True,
@@ -191,9 +192,9 @@ class HostingConfiguration:
                 quota_unit=quota_unit,
             )
 
-    def init_minimax(self) -> HostingProvider:
+    def init_minimax(self, app_config: Config) -> HostingProvider:
         quota_unit = QuotaUnit.TOKENS
-        if os.environ.get("HOSTED_MINIMAX_ENABLED") and os.environ.get("HOSTED_MINIMAX_ENABLED").lower() == 'true':
+        if app_config.get("HOSTED_MINIMAX_ENABLED"):
             quotas = [FreeHostingQuota()]
 
             return HostingProvider(
@@ -208,9 +209,9 @@ class HostingConfiguration:
                 quota_unit=quota_unit,
             )
 
-    def init_spark(self) -> HostingProvider:
+    def init_spark(self, app_config: Config) -> HostingProvider:
         quota_unit = QuotaUnit.TOKENS
-        if os.environ.get("HOSTED_SPARK_ENABLED") and os.environ.get("HOSTED_SPARK_ENABLED").lower() == 'true':
+        if app_config.get("HOSTED_SPARK_ENABLED"):
            quotas = [FreeHostingQuota()]
 
            return HostingProvider(
@@ -225,9 +226,9 @@ class HostingConfiguration:
                 quota_unit=quota_unit,
             )
 
-    def init_zhipuai(self) -> HostingProvider:
+    def init_zhipuai(self, app_config: Config) -> HostingProvider:
         quota_unit = QuotaUnit.TOKENS
-        if os.environ.get("HOSTED_ZHIPUAI_ENABLED") and os.environ.get("HOSTED_ZHIPUAI_ENABLED").lower() == 'true':
+        if app_config.get("HOSTED_ZHIPUAI_ENABLED"):
            quotas = [FreeHostingQuota()]
 
            return HostingProvider(
@@ -242,12 +243,12 @@ class HostingConfiguration:
                 quota_unit=quota_unit,
             )
 
-    def init_moderation_config(self) -> HostedModerationConfig:
-        if os.environ.get("HOSTED_MODERATION_ENABLED") and os.environ.get("HOSTED_MODERATION_ENABLED").lower() == 'true' \
-                and os.environ.get("HOSTED_MODERATION_PROVIDERS"):
+    def init_moderation_config(self, app_config: Config) -> HostedModerationConfig:
+        if app_config.get("HOSTED_MODERATION_ENABLED") \
+                and app_config.get("HOSTED_MODERATION_PROVIDERS"):
             return HostedModerationConfig(
                 enabled=True,
-                providers=os.environ.get("HOSTED_MODERATION_PROVIDERS").split(',')
+                providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(',')
             )
 
         return HostedModerationConfig(

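The recurring pattern in this refactor is replacing `os.environ.get(X) ... .lower() == 'true'` with a bare truthiness check on `app_config.get(X)`. That only behaves the same if the Flask config layer has already coerced those env strings into real booleans when it loads them; a truthiness check on a raw string would also accept `'false'`. A minimal sketch of that assumption (the `load_bool` helper is illustrative, not Dify's actual config loader):

```python
# Why app_config.get("...") can replace the explicit .lower() == 'true' check:
# the config layer is assumed to parse env strings into real booleans first.
import os

def load_bool(name: str, default: str = 'false') -> bool:
    return os.environ.get(name, default).lower() == 'true'

os.environ['HOSTED_OPENAI_TRIAL_ENABLED'] = 'false'
app_config = {'HOSTED_OPENAI_TRIAL_ENABLED': load_bool('HOSTED_OPENAI_TRIAL_ENABLED')}

assert not app_config.get('HOSTED_OPENAI_TRIAL_ENABLED')  # parsed bool: correct
assert os.environ.get('HOSTED_OPENAI_TRIAL_ENABLED')      # raw string 'false' is truthy!
```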
@@ -13,7 +13,7 @@ from core.docstore.dataset_docstore import DatasetDocumentStore
 from core.errors.error import ProviderTokenNotInitError
 from core.generator.llm_generator import LLMGenerator
 from core.index.index import IndexBuilder
-from core.model_manager import ModelManager
+from core.model_manager import ModelManager, ModelInstance
 from core.model_runtime.entities.model_entities import ModelType, PriceType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
@@ -61,8 +61,24 @@ class IndexingRunner:
         # load file
         text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
 
+        # get embedding model instance
+        embedding_model_instance = None
+        if dataset.indexing_technique == 'high_quality':
+            if dataset.embedding_model_provider:
+                embedding_model_instance = self.model_manager.get_model_instance(
+                    tenant_id=dataset.tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
+                )
+            else:
+                embedding_model_instance = self.model_manager.get_default_model_instance(
+                    tenant_id=dataset.tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                )
+
         # get splitter
-        splitter = self._get_splitter(processing_rule)
+        splitter = self._get_splitter(processing_rule, embedding_model_instance)
 
         # split to documents
         documents = self._step_split(
@@ -121,8 +137,24 @@ class IndexingRunner:
         # load file
         text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
 
+        # get embedding model instance
+        embedding_model_instance = None
+        if dataset.indexing_technique == 'high_quality':
+            if dataset.embedding_model_provider:
+                embedding_model_instance = self.model_manager.get_model_instance(
+                    tenant_id=dataset.tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
+                )
+            else:
+                embedding_model_instance = self.model_manager.get_default_model_instance(
+                    tenant_id=dataset.tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                )
+
         # get splitter
-        splitter = self._get_splitter(processing_rule)
+        splitter = self._get_splitter(processing_rule, embedding_model_instance)
 
         # split to documents
         documents = self._step_split(
@@ -242,6 +274,8 @@ class IndexingRunner:
         tokens = 0
         preview_texts = []
         total_segments = 0
+        total_price = 0
+        currency = 'USD'
         for file_detail in file_details:
 
             processing_rule = DatasetProcessRule(
@@ -253,7 +287,7 @@
             text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')
 
             # get splitter
-            splitter = self._get_splitter(processing_rule)
+            splitter = self._get_splitter(processing_rule, embedding_model_instance)
 
             # split to documents
             documents = self._split_to_documents_for_estimate(
@@ -312,11 +346,13 @@
                 price_type=PriceType.INPUT,
                 tokens=tokens
             )
+            total_price = '{:f}'.format(embedding_price_info.total_amount)
+            currency = embedding_price_info.currency
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
-            "currency": embedding_price_info.currency if embedding_model_instance else 'USD',
+            "total_price": total_price,
+            "currency": currency,
             "preview": preview_texts
         }

@@ -356,6 +392,8 @@
         tokens = 0
         preview_texts = []
         total_segments = 0
+        total_price = 0
+        currency = 'USD'
         for notion_info in notion_info_list:
             workspace_id = notion_info['workspace_id']
             data_source_binding = DataSourceBinding.query.filter(
@@ -384,7 +422,7 @@
             )
 
             # get splitter
-            splitter = self._get_splitter(processing_rule)
+            splitter = self._get_splitter(processing_rule, embedding_model_instance)
 
             # split to documents
             documents = self._split_to_documents_for_estimate(
@@ -438,20 +476,22 @@
                 "qa_preview": document_qa_list,
                 "preview": preview_texts
             }
 
-        embedding_model_type_instance = embedding_model_instance.model_type_instance
-        embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
-        embedding_price_info = embedding_model_type_instance.get_price(
-            model=embedding_model_instance.model,
-            credentials=embedding_model_instance.credentials,
-            price_type=PriceType.INPUT,
-            tokens=tokens
-        )
+        if embedding_model_instance:
+            embedding_model_type_instance = embedding_model_instance.model_type_instance
+            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
+            embedding_price_info = embedding_model_type_instance.get_price(
+                model=embedding_model_instance.model,
+                credentials=embedding_model_instance.credentials,
+                price_type=PriceType.INPUT,
+                tokens=tokens
+            )
+            total_price = '{:f}'.format(embedding_price_info.total_amount)
+            currency = embedding_price_info.currency
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0,
-            "currency": embedding_price_info.currency if embedding_model_instance else 'USD',
+            "total_price": total_price,
+            "currency": currency,
             "preview": preview_texts
         }

@@ -499,10 +539,13 @@
     def filter_string(self, text):
         text = re.sub(r'<\|', '<', text)
         text = re.sub(r'\|>', '>', text)
-        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text)
+        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
+        # Unicode  U+FFFE
+        text = re.sub(u'\uFFFE', '', text)
         return text
 
-    def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
+    def _get_splitter(self, processing_rule: DatasetProcessRule,
+                      embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
         """
         Get the NodeParser object according to the processing rule.
         """
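The old character class in `filter_string` is worth a closer look: on a Python `str`, `\x80-\xFF` matches code points U+0080–U+00FF, so it stripped legitimate accented text along with control characters, while missing U+FFFE entirely. The replacement keeps the control-character cleanup and removes U+FFFE explicitly. A small demonstration of the difference:

```python
import re

old = re.compile(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
new = re.compile(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]')

text = 'Café\x00\uFFFE'
assert old.sub('', text) == 'Caf\uFFFE'  # wrongly removed é (U+00E9) but missed U+FFFE
cleaned = re.sub('\uFFFE', '', new.sub('', text))
assert cleaned == 'Café'                 # control char and U+FFFE gone, é kept
```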
@@ -517,19 +560,20 @@
         if separator:
             separator = separator.replace('\\n', '\n')
 
-        character_splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
+        character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
             chunk_size=segmentation["max_tokens"],
             chunk_overlap=0,
             fixed_separator=separator,
-            separators=["\n\n", "。", ".", " ", ""]
+            separators=["\n\n", "。", ".", " ", ""],
+            embedding_model_instance=embedding_model_instance
         )
     else:
         # Automatic segmentation
-        character_splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder(
+        character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
             chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
             chunk_overlap=0,
-            separators=["\n\n", "。", ".", " ", ""]
+            separators=["\n\n", "。", ".", " ", ""],
+            embedding_model_instance=embedding_model_instance
         )
 
     return character_splitter
@@ -714,7 +758,7 @@
         return text
 
     def format_split_text(self, text):
-        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
+        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
         matches = re.findall(regex, text, re.UNICODE)
 
         return [

@@ -12,6 +12,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
 from core.model_runtime.model_providers.__base.rerank_model import RerankModel
 from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+from core.model_runtime.model_providers.__base.tts_model import TTSModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.provider_manager import ProviderManager
@@ -144,7 +145,7 @@ class ModelInstance:
             user=user
         )
 
-    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None, **params) \
+    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
             -> str:
         """
         Invoke large language model
@@ -161,8 +162,29 @@
             model=self.model,
             credentials=self.credentials,
             file=file,
-            user=user,
-            **params
+            user=user
         )
+
+    def invoke_tts(self, content_text: str, streaming: bool, user: Optional[str] = None) \
+            -> str:
+        """
+        Invoke large language model
+
+        :param content_text: text content to be translated
+        :param user: unique user id
+        :param streaming: output is streaming
+        :return: text for given audio file
+        """
+        if not isinstance(self.model_type_instance, TTSModel):
+            raise Exception(f"Model type instance is not TTSModel")
+
+        self.model_type_instance = cast(TTSModel, self.model_type_instance)
+        return self.model_type_instance.invoke(
+            model=self.model,
+            credentials=self.credentials,
+            content_text=content_text,
+            user=user,
+            streaming=streaming
+        )

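A hedged usage sketch for the new `invoke_tts` entry point. The `get_model_instance` keyword signature is taken from the indexing-runner hunk above; the tenant, provider, and model names are placeholders, not values that ship with Dify:

```python
# Hypothetical caller of the new TTS path; tenant/provider/model values are
# placeholders chosen for illustration only.
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

model_instance = ModelManager().get_model_instance(
    tenant_id='tenant-id',
    provider='some_tts_provider',
    model_type=ModelType.TTS,
    model='some-tts-model',
)
audio = model_instance.invoke_tts(
    content_text='Hello from Dify',
    streaming=False,
    user='user-id',
)
```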
@@ -15,7 +15,7 @@ class ModelType(Enum):
     RERANK = "rerank"
     SPEECH2TEXT = "speech2text"
     MODERATION = "moderation"
-    # TTS = "tts"
+    TTS = "tts"
     # TEXT2IMG = "text2img"
 
     @classmethod
@@ -33,6 +33,8 @@
             return cls.RERANK
         elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
             return cls.SPEECH2TEXT
+        elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
+            return cls.TTS
         elif origin_model_type == cls.MODERATION.value:
             return cls.MODERATION
         else:
@@ -52,6 +54,8 @@
             return 'reranking'
         elif self == self.SPEECH2TEXT:
             return 'speech2text'
+        elif self == self.TTS:
+            return 'tts'
         elif self == self.MODERATION:
             return 'moderation'
         else:
@@ -120,6 +124,10 @@ class ModelPropertyKey(Enum):
     FILE_UPLOAD_LIMIT = "file_upload_limit"
     SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
     MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
+    DEFAULT_VOICE = "default_voice"
+    WORD_LIMIT = "word_limit"
+    AUDOI_TYPE = "audio_type"
+    MAX_WORKERS = "max_workers"
 
 
 class ProviderModel(BaseModel):
@@ -149,8 +157,8 @@ class ParameterRule(BaseModel):
     help: Optional[I18nObject] = None
     required: bool = False
     default: Optional[Any] = None
-    min: Optional[float | int] = None
-    max: Optional[float | int] = None
+    min: Optional[float] = None
+    max: Optional[float] = None
     precision: Optional[int] = None
    options: list[str] = []

@@ -1,6 +1,4 @@
 import decimal
-import json
-import logging
 import os
 from abc import ABC, abstractmethod
 from typing import Optional
@@ -12,7 +10,6 @@ from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultPa
     PriceConfig, PriceInfo, PriceType)
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
-from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
 from pydantic import ValidationError
 
 
 class AIModel(ABC):
@@ -54,14 +51,16 @@ class AIModel(ABC):
         :param error: model invoke error
         :return: unified error
         """
+        provider_name = self.__class__.__module__.split('.')[-3]
+
         for invoke_error, model_errors in self._invoke_error_mapping.items():
             if isinstance(error, tuple(model_errors)):
                 if invoke_error == InvokeAuthorizationError:
-                    return invoke_error(description="Incorrect model credentials provided, please check and try again. ")
+                    return invoke_error(description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. ")
 
-                return invoke_error(description=f"{invoke_error.description}: {str(error)}")
+                return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}")
 
-        return InvokeError(description=f"Error: {str(error)}")
+        return InvokeError(description=f"[{provider_name}] Error: {str(error)}")
 
     def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
         """

@@ -1,5 +1,6 @@
 import logging
 import os
+import re
 import time
 from abc import abstractmethod
 from typing import Generator, List, Optional, Union
@@ -212,6 +213,10 @@ class LargeLanguageModel(AIModel):
         """
         raise NotImplementedError
 
+    def enforce_stop_tokens(self, text: str, stop: List[str]) -> str:
+        """Cut off the text as soon as any stop words occur."""
+        return re.split("|".join(stop), text, maxsplit=1)[0]
+
     def _llm_result_to_stream(self, result: LLMResult) -> Generator:
         """
         Transform llm result to stream

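One caveat on the new `enforce_stop_tokens`: the stop strings are joined into a regex alternation without escaping, so any stop word containing regex metacharacters would need `re.escape` from the caller. Basic behavior, with a standalone copy of the function for demonstration:

```python
import re

def enforce_stop_tokens(text: str, stop: list) -> str:
    """Standalone copy of the method above: cut text at the first stop word."""
    return re.split("|".join(stop), text, maxsplit=1)[0]

assert enforce_stop_tokens("I should call a tool.\nObservation: 42",
                           ["\nObservation:"]) == "I should call a tool."
# Metacharacters are interpreted as regex syntax unless escaped by the caller:
assert enforce_stop_tokens("a.b", ["."]) == ""             # '.' acted as a wildcard
assert enforce_stop_tokens("a.b", [re.escape(".")]) == "a" # literal '.' works as intended
```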
api/core/model_runtime/model_providers/__base/tts_model.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+from abc import abstractmethod
+from typing import Optional
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers.__base.ai_model import AIModel
+
+
+class TTSModel(AIModel):
+    """
+    Model class for ttstext model.
+    """
+    model_type: ModelType = ModelType.TTS
+
+    def invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param streaming: output is streaming
+        :param user: unique user id
+        :return: translated audio file
+        """
+        try:
+            return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming, content_text=content_text)
+        except Exception as e:
+            raise self._transform_invoke_error(e)
+
+    @abstractmethod
+    def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param streaming: output is streaming
+        :param user: unique user id
+        :return: translated audio file
+        """
+        raise NotImplementedError
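For orientation, a minimal sketch of what a concrete subclass of this new base class could look like. `EchoTTSModel` and its fake byte output are assumptions for illustration; real providers would call their own synthesis API in `_invoke`:

```python
# Illustrative subclass only; not a provider that ships with Dify.
from typing import Optional

from core.model_runtime.model_providers.__base.tts_model import TTSModel

class EchoTTSModel(TTSModel):
    def _invoke(self, model: str, credentials: dict, content_text: str,
                streaming: bool, user: Optional[str] = None):
        # A real provider would call its TTS API here and return audio bytes
        # (or a generator of audio chunks when streaming=True).
        audio = content_text.encode('utf-8')
        if streaming:
            return iter([audio])
        return audio
```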
@@ -2,11 +2,12 @@
 - anthropic
 - azure_openai
 - google
-- replicate
-- huggingface_hub
 - cohere
+- bedrock
 - togetherai
 - ollama
+- replicate
+- huggingface_hub
 - zhipuai
 - baichuan
 - spark

@@ -1,3 +1,4 @@
+import copy
 import logging
 from typing import Generator, List, Optional, Union, cast

@@ -625,9 +626,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
     def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
         for ai_model_entity in LLM_BASE_MODELS:
             if ai_model_entity.base_model_name == base_model_name:
-                ai_model_entity.entity.model = model
-                ai_model_entity.entity.label.en_US = model
-                ai_model_entity.entity.label.zh_Hans = model
-                return ai_model_entity
+                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
+                ai_model_entity_copy.entity.model = model
+                ai_model_entity_copy.entity.label.en_US = model
+                ai_model_entity_copy.entity.label.zh_Hans = model
+                return ai_model_entity_copy
 
         return None

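The motivation for `copy.deepcopy` here: the `LLM_BASE_MODELS` entries are shared module-level objects, so mutating `ai_model_entity.entity` in place leaks one deployment's custom model name into every later lookup. A self-contained illustration of that aliasing bug and the fix (plain dicts stand in for the Azure entity objects):

```python
import copy

BASE_MODELS = [{'base_model_name': 'gpt-35-turbo', 'entity': {'model': ''}}]

def get_entity_buggy(name: str):
    for e in BASE_MODELS:
        e['entity']['model'] = name   # mutates the shared module-level entry
        return e

def get_entity_fixed(name: str):
    for e in BASE_MODELS:
        e_copy = copy.deepcopy(e)     # isolate callers from each other
        e_copy['entity']['model'] = name
        return e_copy

get_entity_buggy('deployment-a')
assert BASE_MODELS[0]['entity']['model'] == 'deployment-a'  # shared state clobbered

BASE_MODELS[0]['entity']['model'] = ''
get_entity_fixed('deployment-b')
assert BASE_MODELS[0]['entity']['model'] == ''              # base list left untouched
```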
@@ -1,4 +1,5 @@
 import base64
+import copy
 import time
 from typing import Optional, Tuple

@@ -186,9 +187,10 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
     def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
         for ai_model_entity in EMBEDDING_BASE_MODELS:
             if ai_model_entity.base_model_name == base_model_name:
-                ai_model_entity.entity.model = model
-                ai_model_entity.entity.label.en_US = model
-                ai_model_entity.entity.label.zh_Hans = model
-                return ai_model_entity
+                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
+                ai_model_entity_copy.entity.model = model
+                ai_model_entity_copy.entity.label.en_US = model
+                ai_model_entity_copy.entity.label.zh_Hans = model
+                return ai_model_entity_copy
 
         return None

@@ -0,0 +1,14 @@
[new SVG asset: 140×24 AWS logo, likely the Bedrock provider's icon_l_en.svg — 12 KiB of vector path data omitted]
@@ -0,0 +1,15 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g clip-path="url(#clip0_16762_59518)">
|
||||
<path d="M12.6667 0H3.33333C1.49238 0 0 1.49238 0 3.33333V12.6667C0 14.5076 1.49238 16 3.33333 16H12.6667C14.5076 16 16 14.5076 16 12.6667V3.33333C16 1.49238 14.5076 0 12.6667 0Z" fill="url(#paint0_linear_16762_59518)"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M7.99984 12.093L6.3825 12.6323L5.75184 12.2116L6.4385 11.9823L6.22784 11.3503L5.04917 11.743L4.6665 11.4883V9.66631C4.6665 9.54031 4.59517 9.42497 4.4825 9.3683L3.33317 8.79364V7.20564L4.33317 6.70564L5.33317 7.20564V8.33297C5.33317 8.45964 5.4045 8.57497 5.51717 8.63164L6.8505 9.29831L7.14917 8.70164L5.99984 8.12697V7.20564L7.14917 6.63164C7.26184 6.57497 7.33317 6.45964 7.33317 6.33297V5.33297H6.6665V6.12697L5.6665 6.62697L4.6665 6.12697V4.51164L5.33317 4.06697V5.33297H5.99984V3.62297L6.3825 3.36764L7.99984 3.90697V12.093ZM11.6665 11.333C11.8498 11.333 11.9998 11.4823 11.9998 11.6663C11.9998 11.8503 11.8498 11.9996 11.6665 11.9996C11.4832 11.9996 11.3332 11.8503 11.3332 11.6663C11.3332 11.4823 11.4832 11.333 11.6665 11.333ZM10.9998 3.99964C11.1832 3.99964 11.3332 4.14897 11.3332 4.33297C11.3332 4.51697 11.1832 4.6663 10.9998 4.6663C10.8165 4.6663 10.6665 4.51697 10.6665 4.33297C10.6665 4.14897 10.8165 3.99964 10.9998 3.99964ZM12.3332 7.99964C12.5165 7.99964 12.6665 8.14897 12.6665 8.33297C12.6665 8.51697 12.5165 8.66631 12.3332 8.66631C12.1498 8.66631 11.9998 8.51697 11.9998 8.33297C11.9998 8.14897 12.1498 7.99964 12.3332 7.99964ZM11.3945 8.66631C11.5325 9.05364 11.8992 9.33297 12.3332 9.33297C12.8845 9.33297 13.3332 8.88497 13.3332 8.33297C13.3332 7.78164 12.8845 7.33297 12.3332 7.33297C11.8992 7.33297 11.5325 7.61297 11.3945 7.99964H8.6665V6.66631H10.9998C11.1838 6.66631 11.3332 6.51764 11.3332 6.33297V5.27164C11.7205 5.13364 11.9998 4.76697 11.9998 4.33297C11.9998 3.78164 11.5512 3.33297 10.9998 3.33297C10.4485 3.33297 9.99984 3.78164 9.99984 4.33297C9.99984 4.76697 10.2792 5.13364 10.6665 5.27164V5.99964H8.6665V3.6663C8.6665 3.52297 8.5745 3.39564 8.4385 3.3503L6.4385 2.68364C6.3405 2.65097 6.23384 2.66564 6.1485 2.7223L4.1485 4.05564C4.05584 4.11764 3.99984 4.22164 3.99984 4.33297V6.12697L2.8505 6.70164C2.73784 6.75831 2.6665 6.87364 2.6665 6.99964V8.99964C2.6665 9.12631 2.73784 9.24164 2.8505 9.29831L3.99984 9.87231V11.6663C3.99984 11.7776 4.05584 11.8823 4.1485 11.9436L6.1485 13.277C6.20384 13.3143 6.26784 13.333 6.33317 13.333C6.3685 13.333 6.40384 13.3276 6.4385 13.3156L8.4385 12.649C8.5745 12.6043 8.6665 12.477 8.6665 12.333V10.6663H10.1952L10.7638 11.2356L10.7725 11.227C10.7072 11.3603 10.6665 11.5083 10.6665 11.6663C10.6665 12.2176 11.1152 12.6663 11.6665 12.6663C12.2178 12.6663 12.6665 12.2176 12.6665 11.6663C12.6665 11.115 12.2178 10.6663 11.6665 10.6663C11.5078 10.6663 11.3598 10.707 11.2272 10.773L11.2358 10.7643L10.5692 10.0976C10.5065 10.035 10.4218 9.99964 10.3332 9.99964H8.6665V8.66631H11.3945Z" fill="white"/>
</g>
<defs>
<linearGradient id="paint0_linear_16762_59518" x1="0" y1="1600" x2="1600" y2="0" gradientUnits="userSpaceOnUse">
<stop stop-color="#055F4E"/>
<stop offset="1" stop-color="#56C0A7"/>
</linearGradient>
<clipPath id="clip0_16762_59518">
<rect width="16" height="16" fill="white"/>
</clipPath>
</defs>
</svg>
After Width: | Height: | Size: 3.2 KiB |
30
api/core/model_runtime/model_providers/bedrock/bedrock.py
Normal file
@@ -0,0 +1,30 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class BedrockProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials.

        If validation fails, raise an exception.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # Use the lightweight `amazon.titan-text-lite-v1` model to keep the validation call cheap.
            model_instance.validate_credentials(
                model='amazon.titan-text-lite-v1',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
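For readers following this diff: the provider delegates the actual credential check to the LLM model class above. As a point of comparison, the same check can be sketched standalone against the Bedrock control-plane API. The snippet below is a minimal sketch, not part of this commit; it assumes boto3 is installed and Bedrock is enabled for the account, and `check_bedrock_credentials` is a hypothetical helper name, not a Dify API.

import boto3
from botocore.exceptions import ClientError

def check_bedrock_credentials(credentials: dict) -> bool:
    """Hypothetical helper: return True if the AWS keys can reach Bedrock."""
    client = boto3.client(
        'bedrock',
        aws_access_key_id=credentials['aws_access_key_id'],
        aws_secret_access_key=credentials['aws_secret_access_key'],
        region_name=credentials.get('aws_region', 'us-east-1'),
    )
    try:
        # list_foundation_models is a cheap, read-only call that fails fast
        # on invalid keys or on a region where Bedrock is unavailable.
        client.list_foundation_models()
        return True
    except ClientError:
        return False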
71
api/core/model_runtime/model_providers/bedrock/bedrock.yaml
Normal file
@@ -0,0 +1,71 @@
provider: bedrock
label:
  en_US: AWS
description:
  en_US: AWS Bedrock's models.
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.svg
background: "#FCFDFF"
help:
  title:
    en_US: Get your Access Key and Secret Access Key from AWS Console
  url:
    en_US: https://console.aws.amazon.com/
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: aws_access_key_id
      required: true
      label:
        en_US: Access Key
        zh_Hans: Access Key
      type: secret-input
      placeholder:
        en_US: Enter your Access Key
        zh_Hans: 在此输入您的 Access Key
    - variable: aws_secret_access_key
      required: true
      label:
        en_US: Secret Access Key
        zh_Hans: Secret Access Key
      type: secret-input
      placeholder:
        en_US: Enter your Secret Access Key
        zh_Hans: 在此输入您的 Secret Access Key
    - variable: aws_region
      required: true
      label:
        en_US: AWS Region
        zh_Hans: AWS 地区
      type: select
      default: us-east-1
      options:
        - value: us-east-1
          label:
            en_US: US East (N. Virginia)
            zh_Hans: US East (N. Virginia)
        - value: us-west-2
          label:
            en_US: US West (Oregon)
            zh_Hans: US West (Oregon)
        - value: ap-southeast-1
          label:
            en_US: Asia Pacific (Singapore)
            zh_Hans: Asia Pacific (Singapore)
        - value: ap-northeast-1
          label:
            en_US: Asia Pacific (Tokyo)
            zh_Hans: Asia Pacific (Tokyo)
        - value: eu-central-1
          label:
            en_US: Europe (Frankfurt)
            zh_Hans: Europe (Frankfurt)
        - value: us-gov-west-1
          label:
            en_US: AWS GovCloud (US-West)
            zh_Hans: AWS GovCloud (US-West)
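For clarity, each `variable` in the credential form above becomes a key in the dict that `validate_provider_credentials` receives, as its docstring notes. A sketch of that mapping with placeholder values (not real keys):

# Illustrative only: the shape BedrockProvider.validate_provider_credentials()
# is handed when the form is submitted.
credentials = {
    'aws_access_key_id': 'AKIA...',       # secret-input field
    'aws_secret_access_key': '...',       # secret-input field
    'aws_region': 'us-east-1',            # select field; defaults to us-east-1
}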
@@ -0,0 +1,10 @@
- amazon.titan-text-express-v1
- amazon.titan-text-lite-v1
- anthropic.claude-instant-v1
- anthropic.claude-v1
- anthropic.claude-v2
- anthropic.claude-v2:1
- cohere.command-light-text-v14
- cohere.command-text-v14
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
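To make the model IDs above concrete, here is a hedged sketch, not part of this diff, of invoking one of them directly through the Bedrock runtime API with boto3. The request body uses the Anthropic text-completions shape that Bedrock expected for the claude-instant-v1 generation of models; the prompt text and region are illustrative.

import json

import boto3

runtime = boto3.client('bedrock-runtime', region_name='us-east-1')
response = runtime.invoke_model(
    modelId='anthropic.claude-instant-v1',
    contentType='application/json',
    accept='application/json',
    body=json.dumps({
        'prompt': '\n\nHuman: Say hello in one word.\n\nAssistant:',
        'max_tokens_to_sample': 32,
    }),
)
# The response body is a streaming blob; read and decode it to get the text.
print(json.loads(response['body'].read())['completion'])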
Some files were not shown because too many files have changed in this diff.