Mirror of https://github.com/langgenius/dify.git (synced 2026-01-09 15:54:13 +00:00)

Compare commits: 0.5.8 ... 0.5.11-fix (174 commits)
The compare view lists each commit's author, SHA1, and date; only the SHA1 column survived in this mirror. The 174 commit SHA1s, in the order listed:

d14ea2ecaa, a94d86da6d, 5e591fc1b7, 32e83e00e4, 132269618d, 84d118de07, 1716ac562c, e215aae39a,
12782cad4d, fc5ed17fe9, 94d04934b3, 1387f9b23e, 6817eab5f1, 218f591a5d, 17af0de7b6, 9d962053a2,
59909b5ca7, a6cd0f0e73, 2c43393bf1, 669c8c3cca, b0b0cc045f, 20d16d7b31, 714722bb2d, 830495a607,
41a4593b6d, 08b727833e, c8b82b9d08, 5becb4c43a, 13694293e3, 815beac356, 5e60204832, d2624b13a0,
61f5de9662, 40dbf30784, afd77c4745, d70bd4aaa4, 8e05261588, a676d4387c, 08a5afcf9f, eeaa3c1643,
7c8c233cf4, 129a9850eb, 1f98a4fff3, 58e4702b14, c60749678b, d5214e4644, 52804ca6d1, 4fb9606361,
c534d95972, 46ccfda493, 6dc62334d6, c7d003d551, cc754122fc, 240a94182e, 16af509c46, 86e474fff1,
9a3d5729bb, 5a1c29fd8c, 180775a0ec, d018e279f8, 11636bc7c7, 518c1ceb94, 696efe494e, 4419d357c4,
fbbba6db92, 53d428907b, 8133ba16b1, e9aa0e89d3, 7e3c59e53e, f6314f8e73, 3bcfd84fba, 7c0ae76cd0,
2dee8a25d5, 507aa6d949, 59f173f2e6, c3790c239c, 45e51e7730, 4834eae887, 01108e6172, 95b74c211d,
cb79a90031, 4502436c47, c3d0cf940c, e7343cc67c, 83145486b0, 6fd1795d25, f770232b63, a8e694c235,
15a6d94953, 056331981e, cef16862da, 8a4015722d, 156345cb4b, f29280ba5c, 742be06ea9, af98954fc1,
4d63770189, bbea3a6b84, 19d3a56194, 5cab2b711f, 1e5455e266, 4fe585acc2, e52448b84b, 1f92b55f58,
8b15b742ad, 849dc0560b, a026c5fd08, fd7aade26b, 510f8ede10, 8f9125b08a, e5e97c0a0a, 870ca713df,
6854a3fd26, 620360d41a, 20bd49285b, 6bd2730317, f734cca337, ce5b19d011, f82a64d149, f49b1afd6c,
796c5626a7, e54c9cd401, f8951d7f57, 6454e1d644, e184c8cb42, fdd211e399, 7001e21e7d, 82d0732c12,
53cd125780, 3c91f9b5ab, f073dca22a, 8b1e35d7dc, b75d8ca621, 9beefd7d5a, 88145efa97, bdc13f9238,
ce58f0607b, bbc0d330a9, 60e7e17c86, 237bb8514e, bd26c933d2, b6b58da2d2, 40c646cf7a, 3231a8c51c,
4170d6a491, 0b50c525cf, 8ba38e8e74, b163545771, c0b82f8e58, b75ff5fa03, 9440d7fe88, 24809fce07,
9819ad347f, 8fe83750b7, 1809f05904, 0ac250a035, 405a00bb2c, 3a3ca8e6a9, 27e678480e, 7052565380,
31070ffbca, 7f3dec7bee, b1e0db4944, c439952a41, 2f28afebb6, fa7ba30ba3, 1cf5f510ed, 526c874caa,
f88f744097, 95733796f0, 552f319b9d, 38e5952417, 7f891939f1, 69a5ce1e31
`.github/pull_request_template.md` (vendored): 2 lines changed
```diff
@@ -12,6 +12,8 @@ Please delete options that are not relevant.
 - [ ] New feature (non-breaking change which adds functionality)
 - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
 - [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
+- [ ] Improvement, including but not limited to code refactoring, performance optimization, and UI/UX improvement
+- [ ] Dependency upgrade

 # How Has This Been Tested?
```
```diff
@@ -1,17 +1,32 @@
-name: Build and Push API Image
+name: Build and Push API & Web

 on:
   push:
     branches:
-      - 'main'
-      - 'deploy/dev'
+      - "main"
+      - "deploy/dev"
   release:
-    types: [ published ]
+    types: [published]

+env:
+  DOCKERHUB_USER: ${{ secrets.DOCKERHUB_USER }}
+  DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
+  DIFY_WEB_IMAGE_NAME: ${{ vars.DIFY_WEB_IMAGE_NAME || 'langgenius/dify-web' }}
+  DIFY_API_IMAGE_NAME: ${{ vars.DIFY_API_IMAGE_NAME || 'langgenius/dify-api' }}
+
 jobs:
   build-and-push:
     runs-on: ubuntu-latest
     if: github.event.pull_request.draft == false
+    strategy:
+      matrix:
+        include:
+          - service_name: "web"
+            image_name_env: "DIFY_WEB_IMAGE_NAME"
+            context: "web"
+          - service_name: "api"
+            image_name_env: "DIFY_API_IMAGE_NAME"
+            context: "api"
     steps:
       - name: Set up QEMU
         uses: docker/setup-qemu-action@v3
@@ -22,14 +37,14 @@ jobs:
       - name: Login to Docker Hub
         uses: docker/login-action@v2
         with:
-          username: ${{ secrets.DOCKERHUB_USER }}
-          password: ${{ secrets.DOCKERHUB_TOKEN }}
+          username: ${{ env.DOCKERHUB_USER }}
+          password: ${{ env.DOCKERHUB_TOKEN }}

       - name: Extract metadata (tags, labels) for Docker
         id: meta
         uses: docker/metadata-action@v5
         with:
-          images: langgenius/dify-api
+          images: ${{ env[matrix.image_name_env] }}
           tags: |
             type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
             type=ref,event=branch
@@ -39,22 +54,11 @@ jobs:
       - name: Build and push
         uses: docker/build-push-action@v5
         with:
-          context: "{{defaultContext}}:api"
+          context: "{{defaultContext}}:${{ matrix.context }}"
           platforms: ${{ startsWith(github.ref, 'refs/tags/') && 'linux/amd64,linux/arm64' || 'linux/amd64' }}
-          build-args: |
-            COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
+          build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
           push: true
           tags: ${{ steps.meta.outputs.tags }}
           labels: ${{ steps.meta.outputs.labels }}
           cache-from: type=gha
           cache-to: type=gha,mode=max
-
-      - name: Deploy to server
-        if: github.ref == 'refs/heads/deploy/dev'
-        uses: appleboy/ssh-action@v0.1.8
-        with:
-          host: ${{ secrets.SSH_HOST }}
-          username: ${{ secrets.SSH_USER }}
-          key: ${{ secrets.SSH_PRIVATE_KEY }}
-          script: |
-            ${{ secrets.SSH_SCRIPT }}
```
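The consolidated workflow builds both images from a single job definition: each matrix entry names an environment variable (`DIFY_WEB_IMAGE_NAME` or `DIFY_API_IMAGE_NAME`), and `${{ env[matrix.image_name_env] }}` dereferences it at runtime, so a further image would only require a new matrix entry. The SSH deploy step is dropped here; it moves into the separate `deploy-dev.yml` workflow added below.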
`.github/workflows/build-web-image.yml` (vendored): 60 lines changed (file deleted)
```diff
@@ -1,60 +0,0 @@
-name: Build and Push WEB Image
-
-on:
-  push:
-    branches:
-      - 'main'
-      - 'deploy/dev'
-  release:
-    types: [ published ]
-
-jobs:
-  build-and-push:
-    runs-on: ubuntu-latest
-    if: github.event.pull_request.draft == false
-    steps:
-      - name: Set up QEMU
-        uses: docker/setup-qemu-action@v3
-
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
-
-      - name: Login to Docker Hub
-        uses: docker/login-action@v2
-        with:
-          username: ${{ secrets.DOCKERHUB_USER }}
-          password: ${{ secrets.DOCKERHUB_TOKEN }}
-
-      - name: Extract metadata (tags, labels) for Docker
-        id: meta
-        uses: docker/metadata-action@v5
-        with:
-          images: langgenius/dify-web
-          tags: |
-            type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
-            type=ref,event=branch
-            type=sha,enable=true,priority=100,prefix=,suffix=,format=long
-            type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
-
-      - name: Build and push
-        uses: docker/build-push-action@v5
-        with:
-          context: "{{defaultContext}}:web"
-          platforms: ${{ startsWith(github.ref, 'refs/tags/') && 'linux/amd64,linux/arm64' || 'linux/amd64' }}
-          build-args: |
-            COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
-          push: true
-          tags: ${{ steps.meta.outputs.tags }}
-          labels: ${{ steps.meta.outputs.labels }}
-          cache-from: type=gha
-          cache-to: type=gha,mode=max
-
-      - name: Deploy to server
-        if: github.ref == 'refs/heads/deploy/dev'
-        uses: appleboy/ssh-action@v0.1.8
-        with:
-          host: ${{ secrets.SSH_HOST }}
-          username: ${{ secrets.SSH_USER }}
-          key: ${{ secrets.SSH_PRIVATE_KEY }}
-          script: |
-            ${{ secrets.SSH_SCRIPT }}
```
`.github/workflows/deploy-dev.yml` (vendored, new file): 24 lines
```diff
@@ -0,0 +1,24 @@
+name: Deploy Dev
+
+on:
+  workflow_run:
+    workflows: ["Build and Push API & Web"]
+    branches:
+      - "deploy/dev"
+    types:
+      - completed
+
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+    if: |
+      github.event.workflow_run.conclusion == 'success'
+    steps:
+      - name: Deploy to server
+        uses: appleboy/ssh-action@v0.1.8
+        with:
+          host: ${{ secrets.SSH_HOST }}
+          username: ${{ secrets.SSH_USER }}
+          key: ${{ secrets.SSH_PRIVATE_KEY }}
+          script: |
+            ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
```
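Chaining via `workflow_run` means this deploy job fires only after the "Build and Push API & Web" workflow completes on `deploy/dev`, and the `conclusion == 'success'` guard skips deployment when the build failed. Note the script source also changed: `${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}` prefers a repository variable over the secret when both exist.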
`.gitignore` (vendored): 6 lines changed
```diff
@@ -145,10 +145,14 @@ docker/volumes/db/data/*
 docker/volumes/redis/data/*
 docker/volumes/weaviate/*
 docker/volumes/qdrant/*
+docker/volumes/etcd/*
+docker/volumes/minio/*
+docker/volumes/milvus/*

 sdks/python-client/build
 sdks/python-client/dist
 sdks/python-client/dify_client.egg-info

 .vscode/*
-!.vscode/launch.json
+!.vscode/launch.json
+pyrightconfig.json
```
```diff
@@ -155,4 +155,4 @@ And that's it! Once your PR is merged, you will be featured as a contributor in

 ## Getting Help

-If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/AhzKf7dNgk) for a quick chat.
+If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
```

The same Discord-link swap in the Chinese contributing guide (prose translated):

```diff
@@ -152,4 +152,4 @@ Dify's backend is written in Python and uses [Flask](https://flask.palletsprojects.co

 ## Getting Help

-If you run into difficulties or have any questions while contributing, you can raise them via the related GitHub issue, or join our [Discord](https://discord.gg/AhzKf7dNgk) for a quick chat.
+If you run into difficulties or have any questions while contributing, you can raise them via the related GitHub issue, or join our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
```
`Makefile` (new file): 43 lines
```diff
@@ -0,0 +1,43 @@
+# Variables
+DOCKER_REGISTRY=langgenius
+WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
+API_IMAGE=$(DOCKER_REGISTRY)/dify-api
+VERSION=latest
+
+# Build Docker images
+build-web:
+	@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
+	docker build -t $(WEB_IMAGE):$(VERSION) ./web
+	@echo "Web Docker image built successfully: $(WEB_IMAGE):$(VERSION)"
+
+build-api:
+	@echo "Building API Docker image: $(API_IMAGE):$(VERSION)..."
+	docker build -t $(API_IMAGE):$(VERSION) ./api
+	@echo "API Docker image built successfully: $(API_IMAGE):$(VERSION)"
+
+# Push Docker images
+push-web:
+	@echo "Pushing web Docker image: $(WEB_IMAGE):$(VERSION)..."
+	docker push $(WEB_IMAGE):$(VERSION)
+	@echo "Web Docker image pushed successfully: $(WEB_IMAGE):$(VERSION)"
+
+push-api:
+	@echo "Pushing API Docker image: $(API_IMAGE):$(VERSION)..."
+	docker push $(API_IMAGE):$(VERSION)
+	@echo "API Docker image pushed successfully: $(API_IMAGE):$(VERSION)"
+
+# Build all images
+build-all: build-web build-api
+
+# Push all images
+push-all: push-web push-api
+
+build-push-api: build-api push-api
+build-push-web: build-web push-web
+
+# Build and push all images
+build-push-all: build-all push-all
+	@echo "All Docker images have been built and pushed."
+
+# Phony targets
+.PHONY: build-web build-api push-web push-api build-all push-all build-push-all
```
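A usage sketch (assuming Docker is installed and you are logged in to a registry you can push to): `make build-push-all` builds and pushes both images tagged `latest`, while `make build-api VERSION=0.5.11` overrides the tag for a single image, since `VERSION` and `DOCKER_REGISTRY` are ordinary make variables that can be set on the command line.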
`README.md`: 29 lines changed
```diff
@@ -22,19 +22,8 @@
 </p>

 <p align="center">
-  <a href="https://discord.com/events/1082486657678311454/1211724120996188220" target="_blank">
-    Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
-  </a>
-  <ul align="center" style="text-decoration: none; list-style: none;">
-    <li> US EST: 09:00 (9:00 AM)</li>
-    <li> CET: 15:00 (3:00 PM)</li>
-    <li> CST: 22:00 (10:00 PM)</li>
-  </ul>
-</p>
-
-<p align="center">
-  <a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
-    Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
+  <a href="https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6" target="_blank">
+    📌 Check out Dify Premium on AWS and deploy it to your own AWS VPC with one-click.
   </a>
 </p>
@@ -48,6 +37,9 @@
 You can try out [Dify.AI Cloud](https://dify.ai) now. It provides all the capabilities of the self-deployed version, and includes 200 free requests to OpenAI GPT-3.5.

+### Looking to purchase via AWS?
+Check out [Dify Premium on AWS](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click.
+
 ## Dify vs. LangChain vs. Assistants API

 | Feature | Dify.AI | Assistants API | LangChain |
@@ -108,10 +100,12 @@ docker compose up -d

 After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization installation process.

-### Helm Chart
+#### Deploy with Helm Chart

-Big thanks to @BorisPolonsky for providing us with a [Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
-You can go to https://github.com/BorisPolonsky/dify-helm for deployment information.
+[Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
+
+- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
+- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)

 ### Configuration
@@ -128,6 +122,7 @@ For those who'd like to contribute code, see our [Contribution Guide](https://gi

 At the same time, please consider supporting Dify by sharing it on social media and at events and conferences.


 ### Contributors

 <a href="https://github.com/langgenius/dify/graphs/contributors">
@@ -136,7 +131,7 @@ At the same time, please consider supporting Dify by sharing it on social media

 ### Translations

-We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README_EN.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/AhzKf7dNgk).
+We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).

 ## Community & Support
```
The same Helm restructuring in the Chinese README (prose translated):

```diff
@@ -94,10 +94,12 @@ docker compose up -d

 After running, you can open [http://localhost/install](http://localhost/install) in your browser to enter the Dify console and start the initialization setup.

-### Helm Chart
+#### Deploy with Helm Chart

-Many thanks to @BorisPolonsky for providing a [Helm Chart](https://helm.sh/) version that lets Dify be deployed on Kubernetes.
-See https://github.com/BorisPolonsky/dify-helm for deployment information.
+With the [Helm Chart](https://helm.sh/) version, Dify can be deployed on Kubernetes.
+
+- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
+- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)

 ### Configuration
```
```diff
@@ -39,7 +39,7 @@ DB_DATABASE=dify

 # Storage configuration
 # use for store upload files, private keys...
-# storage type: local, s3
+# storage type: local, s3, azure-blob
 STORAGE_TYPE=local
 STORAGE_LOCAL_PATH=storage
 S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com
@@ -47,6 +47,11 @@ S3_BUCKET_NAME=your-bucket-name
 S3_ACCESS_KEY=your-access-key
 S3_SECRET_KEY=your-secret-key
 S3_REGION=your-region
+# Azure Blob Storage configuration
+AZURE_BLOB_ACCOUNT_NAME=your-account-name
+AZURE_BLOB_ACCOUNT_KEY=your-account-key
+AZURE_BLOB_CONTAINER_NAME=yout-container-name
+AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net

 # CORS configuration
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
@@ -132,3 +137,4 @@ SSRF_PROXY_HTTP_URL=
 SSRF_PROXY_HTTPS_URL=

 BATCH_UPLOAD_LIMIT=10
+KEYWORD_DATA_SOURCE_TYPE=database
```
````diff
@@ -5,7 +5,7 @@
 1. Start the docker-compose stack

    The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.

    ```bash
    cd ../docker
    docker-compose -f docker-compose.middleware.yaml -p dify up -d
@@ -15,7 +15,7 @@
 3. Generate a `SECRET_KEY` in the `.env` file.

    ```bash
-   openssl rand -base64 42
+   sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
    ```
 3.5 If you use annaconda, create a new environment and activate it

    ```bash
@@ -46,7 +46,7 @@
    ```
    pip install -r requirements.txt --upgrade --force-reinstall
    ```

 6. Start backend:

    ```bash
    flask run --host 0.0.0.0 --port=5001 --debug
````
```diff
@@ -26,6 +26,7 @@ from config import CloudEditionConfig, Config
 from extensions import (
     ext_celery,
     ext_code_based_extension,
+    ext_compress,
     ext_database,
     ext_hosting_provider,
     ext_login,
@@ -96,6 +97,7 @@ def create_app(test_config=None) -> Flask:
 def initialize_extensions(app):
     # Since the application instance is now created, pass it to each Flask
     # extension instance to bind it to the Flask application instance (app)
+    ext_compress.init_app(app)
     ext_code_based_extension.init()
     ext_database.init_app(app)
     ext_migrate.init(app, db)
```
```diff
@@ -109,19 +109,20 @@ def reset_encrypt_key_pair():
         click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
         return

-    tenant = db.session.query(Tenant).first()
-    if not tenant:
-        click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
-        return
+    tenants = db.session.query(Tenant).all()
+    for tenant in tenants:
+        if not tenant:
+            click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
+            return

-    tenant.encrypt_public_key = generate_key_pair(tenant.id)
+        tenant.encrypt_public_key = generate_key_pair(tenant.id)

-    db.session.query(Provider).filter(Provider.provider_type == 'custom').delete()
-    db.session.query(ProviderModel).delete()
-    db.session.commit()
+        db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete()
+        db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
+        db.session.commit()

-    click.echo(click.style('Congratulations! '
-                           'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
+        click.echo(click.style('Congratulations! '
+                               'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))


 @click.command('vdb-migrate', help='migrate vector db.')
@@ -254,7 +255,7 @@ def migrate_knowledge_vector_database():
     for dataset in datasets:
         total_count = total_count + 1
         click.echo(f'Processing the {total_count} dataset {dataset.id}. '
-                   + f'{create_count} created, ${skipped_count} skipped.')
+                   + f'{create_count} created, {skipped_count} skipped.')
         try:
             click.echo('Create dataset vdb index: {}'.format(dataset.id))
             if dataset.index_struct_dict:
```
```diff
@@ -22,6 +22,7 @@ DEFAULTS = {
     'SERVICE_API_URL': 'https://api.dify.ai',
     'APP_WEB_URL': 'https://udify.app',
     'FILES_URL': '',
+    'S3_ADDRESS_STYLE': 'auto',
     'STORAGE_TYPE': 'local',
     'STORAGE_LOCAL_PATH': 'storage',
     'CHECK_UPDATE_URL': 'https://updates.dify.ai',
@@ -59,7 +60,9 @@ DEFAULTS = {
     'CAN_REPLACE_LOGO': 'False',
     'ETL_TYPE': 'dify',
     'KEYWORD_STORE': 'jieba',
-    'BATCH_UPLOAD_LIMIT': 20
+    'BATCH_UPLOAD_LIMIT': 20,
+    'TOOL_ICON_CACHE_MAX_AGE': 3600,
+    'KEYWORD_DATA_SOURCE_TYPE': 'database',
 }
@@ -90,7 +93,7 @@ class Config:
         # ------------------------
         # General Configurations.
         # ------------------------
-        self.CURRENT_VERSION = "0.5.8"
+        self.CURRENT_VERSION = "0.5.11-fix1"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -180,6 +183,11 @@ class Config:
         self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
         self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
         self.S3_REGION = get_env('S3_REGION')
+        self.S3_ADDRESS_STYLE = get_env('S3_ADDRESS_STYLE')
+        self.AZURE_BLOB_ACCOUNT_NAME = get_env('AZURE_BLOB_ACCOUNT_NAME')
+        self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY')
+        self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME')
+        self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL')

         # ------------------------
         # Vector Store Configurations.
@@ -293,6 +301,10 @@ class Config:
         self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')

+        self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
+        self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
+
+        self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')

 class CloudEditionConfig(Config):
```
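The new `DEFAULTS` entries only take effect through the `get_env`/`get_bool_env` helpers referenced in this diff; their definitions are not shown here, but a minimal sketch consistent with how they are called might look like this (the exact bodies are an assumption, not the file's actual code):

```python
import os

# trimmed-down stand-in for the DEFAULTS table shown in the diff
DEFAULTS = {
    'TOOL_ICON_CACHE_MAX_AGE': 3600,
    'KEYWORD_DATA_SOURCE_TYPE': 'database',
    'API_COMPRESSION_ENABLED': 'False',
}

def get_env(key):
    # an environment variable wins; otherwise fall back to DEFAULTS
    return os.environ.get(key, DEFAULTS.get(key))

def get_bool_env(key):
    # normalize the usual truthy spelling ("true"/"True") to a bool
    return str(get_env(key)).lower() == 'true'
```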
```diff
@@ -2,7 +2,7 @@ import json

 from models.model import AppModelConfig

-languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA']
+languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN']

 language_timezone_mapping = {
     'en-US': 'America/New_York',
@@ -16,6 +16,7 @@ language_timezone_mapping = {
     'ru-RU': 'Europe/Moscow',
     'it-IT': 'Europe/Rome',
     'uk-UA': 'Europe/Kyiv',
+    'vi-VN': 'Asia/Ho_Chi_Minh',
 }
@@ -79,6 +80,16 @@ user_input_form_template = {
         }
     ],
+    "vi-VN": [
+        {
+            "paragraph": {
+                "label": "Nội dung truy vấn",
+                "variable": "default_input",
+                "required": False,
+                "default": ""
+            }
+        }
+    ],
 }

 demo_model_templates = {
@@ -208,7 +219,6 @@ demo_model_templates = {
             )
         }
     ],
-
     'zh-Hans': [
         {
             'name': '翻译助手',
```
The hunk `@@ -335,91 +345,92 @@` is whitespace-only: the `'uk-UA'` demo-template block was re-indented from the compact `'uk-UA': [{ ... }]` form to a nested list item (which adds one line); no strings or values changed. The block in its new form:

```python
    'uk-UA': [
        {
            "name": "Помічник перекладу",
            "icon": "",
            "icon_background": "",
            "description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.",
            "mode": "completion",
            "model_config": AppModelConfig(
                provider="openai",
                model_id="gpt-3.5-turbo-instruct",
                configs={
                    "prompt_template": "Будь ласка, перекладіть наступний текст на {{target_language}}:\n",
                    "prompt_variables": [
                        {
                            "key": "target_language",
                            "name": "Цільова мова",
                            "description": "Мова, на яку ви хочете перекласти.",
                            "type": "select",
                            "default": "Ukrainian",
                            "options": [
                                "Chinese",
                                "English",
                                "Japanese",
                                "French",
                                "Russian",
                                "German",
                                "Spanish",
                                "Korean",
                                "Italian",
                            ],
                        },
                    ],
                    "completion_params": {
                        "max_token": 1000,
                        "temperature": 0,
                        "top_p": 0,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1,
                    },
                },
                opening_statement="",
                suggested_questions=None,
                pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:",
                model=json.dumps({
                    "provider": "openai",
                    "name": "gpt-3.5-turbo-instruct",
                    "mode": "completion",
                    "completion_params": {
                        "max_tokens": 1000,
                        "temperature": 0,
                        "top_p": 0,
                        "presence_penalty": 0.1,
                        "frequency_penalty": 0.1,
                    },
                }),
                user_input_form=json.dumps([
                    {
                        "select": {
                            "label": "Цільова мова",
                            "variable": "target_language",
                            "description": "Мова, на яку ви хочете перекласти.",
                            "default": "Chinese",
                            "required": True,
                            'options': [
                                'Chinese',
                                'English',
                                'Japanese',
                                'French',
                                'Russian',
                                'German',
                                'Spanish',
                                'Korean',
                                'Italian',
                            ]
                        }
                    }, {
                        "paragraph": {
                            "label": "Запит",
                            "variable": "query",
                            "required": True,
                            "default": ""
                        }
                    }
                ])
            )
        },
        {
            "name": "AI інтерв’юер фронтенду",
            "icon": "",
            "icon_background": "",
```
```diff
@@ -460,5 +471,132 @@ demo_model_templates = {
             ),
         }
     ],
+
+    'vi-VN': [
+        {
+            'name': 'Trợ lý dịch thuật',
+            'icon': '',
+            'icon_background': '',
+            'description': 'Trình dịch đa ngôn ngữ cung cấp khả năng dịch bằng nhiều ngôn ngữ, dịch thông tin đầu vào của người dùng sang ngôn ngữ họ cần.',
+            'mode': 'completion',
+            'model_config': AppModelConfig(
+                provider='openai',
+                model_id='gpt-3.5-turbo-instruct',
+                configs={
+                    'prompt_template': "Hãy dịch đoạn văn bản sau sang ngôn ngữ {{target_language}}:\n",
+                    'prompt_variables': [
+                        {
+                            "key": "target_language",
+                            "name": "Ngôn ngữ đích",
+                            "description": "Ngôn ngữ bạn muốn dịch sang.",
+                            "type": "select",
+                            "default": "Vietnamese",
+                            'options': [
+                                'Chinese',
+                                'English',
+                                'Japanese',
+                                'French',
+                                'Russian',
+                                'German',
+                                'Spanish',
+                                'Korean',
+                                'Italian',
+                                'Vietnamese',
+                            ]
+                        }
+                    ],
+                    'completion_params': {
+                        'max_token': 1000,
+                        'temperature': 0,
+                        'top_p': 0,
+                        'presence_penalty': 0.1,
+                        'frequency_penalty': 0.1,
+                    }
+                },
+                opening_statement='',
+                suggested_questions=None,
+                pre_prompt="Hãy dịch đoạn văn bản sau sang {{target_language}}:\n{{query}}\ndịch:",
+                model=json.dumps({
+                    "provider": "openai",
+                    "name": "gpt-3.5-turbo-instruct",
+                    "mode": "completion",
+                    "completion_params": {
+                        "max_tokens": 1000,
+                        "temperature": 0,
+                        "top_p": 0,
+                        "presence_penalty": 0.1,
+                        "frequency_penalty": 0.1
+                    }
+                }),
+                user_input_form=json.dumps([
+                    {
+                        "select": {
+                            "label": "Ngôn ngữ đích",
+                            "variable": "target_language",
+                            "description": "Ngôn ngữ bạn muốn dịch sang.",
+                            "default": "Vietnamese",
+                            "required": True,
+                            'options': [
+                                'Chinese',
+                                'English',
+                                'Japanese',
+                                'French',
+                                'Russian',
+                                'German',
+                                'Spanish',
+                                'Korean',
+                                'Italian',
+                                'Vietnamese',
+                            ]
+                        }
+                    }, {
+                        "paragraph": {
+                            "label": "Query",
+                            "variable": "query",
+                            "required": True,
+                            "default": ""
+                        }
+                    }
+                ])
+            )
+        },
+        {
+            'name': 'Phỏng vấn front-end AI',
+            'icon': '',
+            'icon_background': '',
+            'description': 'Một người phỏng vấn front-end mô phỏng để kiểm tra mức độ kỹ năng phát triển front-end thông qua việc đặt câu hỏi.',
+            'mode': 'chat',
+            'model_config': AppModelConfig(
+                provider='openai',
+                model_id='gpt-3.5-turbo',
+                configs={
+                    'introduction': 'Xin chào, chào mừng đến với cuộc phỏng vấn của chúng tôi. Tôi là người phỏng vấn cho công ty công nghệ này và tôi sẽ kiểm tra kỹ năng phát triển web front-end của bạn. Tiếp theo, tôi sẽ hỏi bạn một số câu hỏi kỹ thuật. Hãy trả lời chúng càng kỹ lưỡng càng tốt. ',
+                    'prompt_template': "Bạn sẽ đóng vai người phỏng vấn cho một công ty công nghệ, kiểm tra kỹ năng phát triển web front-end của người dùng và đặt ra 5-10 câu hỏi kỹ thuật sắc bén.\n\nXin lưu ý:\n- Mỗi lần chỉ hỏi một câu hỏi.\n - Sau khi người dùng trả lời một câu hỏi, hãy hỏi trực tiếp câu hỏi tiếp theo mà không cố gắng sửa bất kỳ lỗi nào mà thí sinh mắc phải.\n- Nếu bạn cho rằng người dùng đã không trả lời đúng cho một số câu hỏi liên tiếp, hãy hỏi ít câu hỏi hơn.\n- Sau đặt câu hỏi cuối cùng, bạn có thể hỏi câu hỏi này: Tại sao bạn lại rời bỏ công việc cuối cùng của mình? Sau khi người dùng trả lời câu hỏi này, vui lòng bày tỏ sự hiểu biết và ủng hộ của bạn.\n",
+                    'prompt_variables': [],
+                    'completion_params': {
+                        'max_token': 300,
+                        'temperature': 0.8,
+                        'top_p': 0.9,
+                        'presence_penalty': 0.1,
+                        'frequency_penalty': 0.1,
+                    }
+                },
+                opening_statement='Xin chào, chào mừng đến với cuộc phỏng vấn của chúng tôi. Tôi là người phỏng vấn cho công ty công nghệ này và tôi sẽ kiểm tra kỹ năng phát triển web front-end của bạn. Tiếp theo, tôi sẽ hỏi bạn một số câu hỏi kỹ thuật. Hãy trả lời chúng càng kỹ lưỡng càng tốt. ',
+                suggested_questions=None,
+                pre_prompt="Bạn sẽ đóng vai người phỏng vấn cho một công ty công nghệ, kiểm tra kỹ năng phát triển web front-end của người dùng và đặt ra 5-10 câu hỏi kỹ thuật sắc bén.\n\nXin lưu ý:\n- Mỗi lần chỉ hỏi một câu hỏi.\n - Sau khi người dùng trả lời một câu hỏi, hãy hỏi trực tiếp câu hỏi tiếp theo mà không cố gắng sửa bất kỳ lỗi nào mà thí sinh mắc phải.\n- Nếu bạn cho rằng người dùng đã không trả lời đúng cho một số câu hỏi liên tiếp, hãy hỏi ít câu hỏi hơn.\n- Sau đặt câu hỏi cuối cùng, bạn có thể hỏi câu hỏi này: Tại sao bạn lại rời bỏ công việc cuối cùng của mình? Sau khi người dùng trả lời câu hỏi này, vui lòng bày tỏ sự hiểu biết và ủng hộ của bạn.\n",
+                model=json.dumps({
+                    "provider": "openai",
+                    "name": "gpt-3.5-turbo",
+                    "mode": "chat",
+                    "completion_params": {
+                        "max_tokens": 300,
+                        "temperature": 0.8,
+                        "top_p": 0.9,
+                        "presence_penalty": 0.1,
+                        "frequency_penalty": 0.1
+                    }
+                }),
+                user_input_form=None
+            )
+        }
+    ],
 }
```
```diff
@@ -27,7 +27,9 @@ from fields.app_fields import (
 from libs.login import login_required
 from models.model import App, AppModelConfig, Site
 from services.app_model_config_service import AppModelConfigService
-
+from core.tools.utils.configuration import ToolParameterConfigurationManager
+from core.tools.tool_manager import ToolManager
+from core.entities.application_entities import AgentToolEntity

 def _get_app(app_id, tenant_id):
     app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
@@ -236,7 +238,44 @@ class AppApi(Resource):
     def get(self, app_id):
         """Get app detail"""
         app_id = str(app_id)
-        app = _get_app(app_id, current_user.current_tenant_id)
+        app: App = _get_app(app_id, current_user.current_tenant_id)
+
+        # get original app model config
+        model_config: AppModelConfig = app.app_model_config
+        agent_mode = model_config.agent_mode_dict
+        # decrypt agent tool parameters if it's secret-input
+        for tool in agent_mode.get('tools') or []:
+            if not isinstance(tool, dict) or len(tool.keys()) <= 3:
+                continue
+            agent_tool_entity = AgentToolEntity(**tool)
+            # get tool
+            try:
+                tool_runtime = ToolManager.get_agent_tool_runtime(
+                    tenant_id=current_user.current_tenant_id,
+                    agent_tool=agent_tool_entity,
+                    agent_callback=None
+                )
+                manager = ToolParameterConfigurationManager(
+                    tenant_id=current_user.current_tenant_id,
+                    tool_runtime=tool_runtime,
+                    provider_name=agent_tool_entity.provider_id,
+                    provider_type=agent_tool_entity.provider_type,
+                )
+
+                # get decrypted parameters
+                if agent_tool_entity.tool_parameters:
+                    parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
+                    masked_parameter = manager.mask_tool_parameters(parameters or {})
+                else:
+                    masked_parameter = {}
+
+                # override tool parameters
+                tool['tool_parameters'] = masked_parameter
+            except Exception as e:
+                pass
+
+        # override agent mode
+        model_config.agent_mode = json.dumps(agent_mode)
+
+        return app
```
```diff
@@ -1,3 +1,4 @@
+import json

 from flask import request
 from flask_login import current_user
@@ -7,6 +8,9 @@ from controllers.console import api
 from controllers.console.app import _get_app
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from core.entities.application_entities import AgentToolEntity
+from core.tools.tool_manager import ToolManager
+from core.tools.utils.configuration import ToolParameterConfigurationManager
 from events.app_event import app_model_config_was_updated
 from extensions.ext_database import db
 from libs.login import login_required
@@ -38,6 +42,91 @@ class ModelConfigResource(Resource):
         )
         new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)

+        # get original app model config
+        original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
+            AppModelConfig.id == app.app_model_config_id
+        ).first()
+        agent_mode = original_app_model_config.agent_mode_dict
+        # decrypt agent tool parameters if it's secret-input
+        parameter_map = {}
+        masked_parameter_map = {}
+        tool_map = {}
+        for tool in agent_mode.get('tools') or []:
+            if not isinstance(tool, dict) or len(tool.keys()) <= 3:
+                continue
+
+            agent_tool_entity = AgentToolEntity(**tool)
+            # get tool
+            try:
+                tool_runtime = ToolManager.get_agent_tool_runtime(
+                    tenant_id=current_user.current_tenant_id,
+                    agent_tool=agent_tool_entity,
+                    agent_callback=None
+                )
+                manager = ToolParameterConfigurationManager(
+                    tenant_id=current_user.current_tenant_id,
+                    tool_runtime=tool_runtime,
+                    provider_name=agent_tool_entity.provider_id,
+                    provider_type=agent_tool_entity.provider_type,
+                )
+            except Exception as e:
+                continue
+
+            # get decrypted parameters
+            if agent_tool_entity.tool_parameters:
+                parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
+                masked_parameter = manager.mask_tool_parameters(parameters or {})
+            else:
+                parameters = {}
+                masked_parameter = {}
+
+            key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
+            masked_parameter_map[key] = masked_parameter
+            parameter_map[key] = parameters
+            tool_map[key] = tool_runtime
+
+        # encrypt agent tool parameters if it's secret-input
+        agent_mode = new_app_model_config.agent_mode_dict
+        for tool in agent_mode.get('tools') or []:
+            agent_tool_entity = AgentToolEntity(**tool)
+
+            # get tool
+            key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
+            if key in tool_map:
+                tool_runtime = tool_map[key]
+            else:
+                try:
+                    tool_runtime = ToolManager.get_agent_tool_runtime(
+                        tenant_id=current_user.current_tenant_id,
+                        agent_tool=agent_tool_entity,
+                        agent_callback=None
+                    )
+                except Exception as e:
+                    continue
+
+            manager = ToolParameterConfigurationManager(
+                tenant_id=current_user.current_tenant_id,
+                tool_runtime=tool_runtime,
+                provider_name=agent_tool_entity.provider_id,
+                provider_type=agent_tool_entity.provider_type,
+            )
+            manager.delete_tool_parameters_cache()
+
+            # override parameters if it equals to masked parameters
+            if agent_tool_entity.tool_parameters:
+                if key not in masked_parameter_map:
+                    continue
+
+                if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
+                    agent_tool_entity.tool_parameters = parameter_map[key]
+
+            # encrypt parameters
+            if agent_tool_entity.tool_parameters:
+                tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
+
+        # update app model config
+        new_app_model_config.agent_mode = json.dumps(agent_mode)
+
         db.session.add(new_app_model_config)
         db.session.flush()
```
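This save path pairs with the read path added in `app.py` above: secret tool parameters are masked before they reach the browser, and if the browser sends the mask back unchanged, the stored values are restored before re-encryption. A self-contained sketch of that round-trip (the mask token and helper names here are illustrative, not Dify's actual ones):

```python
MASK = '[__HIDDEN__]'  # illustrative; the real mask format is internal to Dify

def mask_secrets(params: dict, secret_keys: set[str]) -> dict:
    # what the API returns to the client instead of stored secret values
    return {k: (MASK if k in secret_keys else v) for k, v in params.items()}

def merge_on_save(submitted: dict, masked: dict, original: dict) -> dict:
    # if the client echoed back exactly what it was shown, it did not edit
    # the secrets, so keep the previously stored plaintext values
    return original if submitted == masked else submitted

stored = {'api_key': 'sk-real-value', 'region': 'us-east-1'}
shown = mask_secrets(stored, {'api_key'})
assert merge_on_save(shown, shown, stored) == stored           # untouched form
assert merge_on_save({'api_key': 'sk-new', 'region': 'us-east-1'},
                     shown, stored)['api_key'] == 'sk-new'     # edited secret
```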
```diff
@@ -52,6 +52,12 @@ class RecommendedAppListApi(Resource):
             RecommendedApp.language == language_prefix
         ).all()

+        if len(recommended_apps) == 0:
+            recommended_apps = db.session.query(RecommendedApp).filter(
+                RecommendedApp.is_listed == True,
+                RecommendedApp.language == languages[0]
+            ).all()
+
         categories = set()
         current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
         recommended_apps_result = []
```
```diff
@@ -1,6 +1,6 @@
 import io

-from flask import send_file
+from flask import current_app, send_file
 from flask_login import current_user
 from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden
@@ -80,8 +80,33 @@ class ToolBuiltinProviderIconApi(Resource):
     @setup_required
     def get(self, provider):
         icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider)
-        return send_file(io.BytesIO(icon_bytes), mimetype=minetype)
+        icon_cache_max_age = int(current_app.config.get('TOOL_ICON_CACHE_MAX_AGE'))
+        return send_file(io.BytesIO(icon_bytes), mimetype=minetype, max_age=icon_cache_max_age)
+
+class ToolModelProviderIconApi(Resource):
+    @setup_required
+    def get(self, provider):
+        icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider)
+        return send_file(io.BytesIO(icon_bytes), mimetype=mimetype)
+
+class ToolModelProviderListToolsApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        user_id = current_user.id
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
+
+        args = parser.parse_args()
+
+        return ToolManageService.list_model_tool_provider_tools(
+            user_id,
+            tenant_id,
+            args['provider'],
+        )

 class ToolApiProviderAddApi(Resource):
     @setup_required
@@ -283,6 +308,8 @@ api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provide
 api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
 api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
 api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
+api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon')
+api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools')
 api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
 api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
 api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
```
```diff
@@ -44,7 +44,7 @@ class AudioApi(Resource):
         response = AudioService.transcript_asr(
             tenant_id=app_model.tenant_id,
             file=file,
-            end_user=end_user
+            end_user=end_user.get_id()
         )

         return response
@@ -75,7 +75,7 @@ class AudioApi(Resource):

 class TextApi(Resource):
-    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     def post(self, app_model: App, end_user: EndUser):
         parser = reqparse.RequestParser()
         parser.add_argument('text', type=str, required=True, nullable=False, location='json')
@@ -86,8 +86,8 @@ class TextApi(Resource):
         response = AudioService.transcript_tts(
             tenant_id=app_model.tenant_id,
             text=args['text'],
-            end_user=end_user,
-            voice=args['voice'] if args['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
+            end_user=end_user.get_id(),
+            voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
             streaming=args['streaming']
         )
```
```diff
@@ -197,11 +197,11 @@ class DatasetSegmentApi(DatasetApiResource):

         # validate args
         parser = reqparse.RequestParser()
-        parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
+        parser.add_argument('segment', type=dict, required=False, nullable=True, location='json')
         args = parser.parse_args()

-        SegmentService.segment_create_args_validate(args['segments'], document)
-        segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
+        SegmentService.segment_create_args_validate(args['segment'], document)
+        segment = SegmentService.update_segment(args['segment'], segment, document, dataset)
         return {
             'data': marshal(segment, segment_fields),
             'doc_form': document.doc_form
```
```diff
@@ -195,6 +195,10 @@ class AssistantApplicationRunner(AppRunner):
         if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
             agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

+        db.session.refresh(conversation)
+        db.session.refresh(message)
+        db.session.close()
+
         # start agent runner
         if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
             assistant_cot_runner = AssistantCotApplicationRunner(
```
```diff
@@ -192,6 +192,8 @@ class BasicApplicationRunner(AppRunner):
             model=app_orchestration_config.model_config.model
         )

+        db.session.close()
+
         invoke_result = model_instance.invoke_llm(
             prompt_messages=prompt_messages,
             model_parameters=app_orchestration_config.model_config.parameters,
```
```diff
@@ -89,6 +89,10 @@ class GenerateTaskPipeline:
         Process generate task pipeline.
         :return:
         """
+        db.session.refresh(self._conversation)
+        db.session.refresh(self._message)
+        db.session.close()
+
         if stream:
             return self._process_stream_response()
         else:
@@ -303,6 +307,7 @@ class GenerateTaskPipeline:
                 .first()
             )
+            db.session.refresh(agent_thought)
             db.session.close()

             if agent_thought:
                 response = {
@@ -330,6 +335,8 @@ class GenerateTaskPipeline:
                 .filter(MessageFile.id == event.message_file_id)
                 .first()
             )
+            db.session.close()
+
             # get extension
             if '.' in message_file.url:
                 extension = f'.{message_file.url.split(".")[-1]}'
@@ -413,6 +420,7 @@ class GenerateTaskPipeline:
         usage = llm_result.usage

+        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
         self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()

         self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
         self._message.message_tokens = usage.prompt_tokens
```
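The scattered `db.session.refresh(...)` / `db.session.close()` calls in these hunks all follow one pattern: reload an ORM object's attributes while the session still holds a connection, then release the connection so a long-running generation does not pin it. A runnable miniature of that pattern with plain SQLAlchemy (in-memory SQLite stands in for Dify's database):

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Conversation(Base):
    __tablename__ = 'conversation'
    id = Column(Integer, primary_key=True)
    name = Column(String)

engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

session = Session(engine)
session.add(Conversation(id=1, name='demo'))
session.commit()                      # commit expires loaded attributes

conv = session.get(Conversation, 1)
session.refresh(conv)                 # reload attributes from the database
session.close()                       # return the connection to the pool

print(conv.name)                      # already-loaded attributes stay readable
```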
```diff
@@ -35,7 +35,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
 from core.file.file_obj import FileObj
 from core.model_runtime.entities.message_entities import PromptMessageRole
 from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.prompt_template import PromptTemplateParser
 from core.provider_manager import ProviderManager
@@ -195,13 +195,11 @@ class ApplicationManager:
         except ValidationError as e:
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-        except (ValueError, InvokeError) as e:
-            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except Exception as e:
             logger.exception("Unknown Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         finally:
-            db.session.remove()
+            db.session.close()

     def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
                          queue_manager: ApplicationQueueManager,
@@ -233,8 +231,6 @@ class ApplicationManager:
             else:
                 logger.exception(e)
                 raise e
-        finally:
-            db.session.remove()

     def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
             -> AppOrchestrationConfigEntity:
@@ -651,6 +647,7 @@ class ApplicationManager:

             db.session.add(conversation)
             db.session.commit()
+            db.session.refresh(conversation)
         else:
             conversation = (
                 db.session.query(Conversation)
@@ -689,6 +686,7 @@ class ApplicationManager:

         db.session.add(message)
         db.session.commit()
+        db.session.refresh(message)

         for file in application_generate_entity.files:
             message_file = MessageFile(
```
```diff
@@ -1,13 +1,14 @@
 import enum
-import importlib.util
+import importlib
 import json
 import logging
 import os
-from collections import OrderedDict
 from typing import Any, Optional

 from pydantic import BaseModel

+from core.utils.position_helper import sort_to_dict_by_position_map
+

 class ExtensionModule(enum.Enum):
     MODERATION = 'moderation'
@@ -36,7 +37,8 @@ class Extensible:

     @classmethod
     def scan_extensions(cls):
-        extensions = {}
+        extensions: list[ModuleExtension] = []
+        position_map = {}

         # get the path of the current class
         current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
@@ -63,6 +65,7 @@ class Extensible:
             if os.path.exists(builtin_file_path):
                 with open(builtin_file_path, encoding='utf-8') as f:
                     position = int(f.read().strip())
+                    position_map[extension_name] = position

             if (extension_name + '.py') not in file_names:
                 logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
@@ -96,16 +99,15 @@ class Extensible:
                 with open(json_path, encoding='utf-8') as f:
                     json_data = json.load(f)

-            extensions[extension_name] = ModuleExtension(
+            extensions.append(ModuleExtension(
                 extension_class=extension_class,
                 name=extension_name,
                 label=json_data.get('label'),
                 form_schema=json_data.get('form_schema'),
                 builtin=builtin,
                 position=position
-            )
+            ))

-        sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
-        sorted_extensions = OrderedDict(sorted_items)
+        sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)

         return sorted_extensions
```
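`sort_to_dict_by_position_map` itself is not shown in this diff; judging only from the call site (a position map, a list of extensions, and a name extractor), a plausible sketch of what it does is below. Treat the body as an assumption, not Dify's actual helper:

```python
from collections import OrderedDict
from typing import Callable, TypeVar

T = TypeVar('T')

def sort_to_dict_by_position_map(position_map: dict[str, int],
                                 data: list[T],
                                 name_func: Callable[[T], str]) -> OrderedDict:
    # items with an explicit position come first, ordered by it;
    # unpositioned items keep their scan order at the end (sorted() is stable)
    ordered = sorted(
        data,
        key=lambda item: (name_func(item) not in position_map,
                          position_map.get(name_func(item), 0)),
    )
    return OrderedDict((name_func(item), item) for item in ordered)
```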
@@ -114,6 +114,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
|
||||
MessageAgentThought.message_id == self.message.id,
|
||||
).count()
|
||||
db.session.close()
|
||||
|
||||
# check if model supports stream tool call
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
@@ -144,7 +145,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
result += f"result link: {response.message}. please tell user to check it."
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
        response.type == ToolInvokeMessage.MessageType.IMAGE:
    result += "image has been created and sent to user already, you should tell user to check it now."
    result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
else:
    result += f"tool response: {response.message}."

@@ -154,9 +155,9 @@ class BaseAssistantApplicationRunner(AppRunner):
    """
    convert tool to prompt message tool
    """
    tool_entity = ToolManager.get_tool_runtime(
        provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
        tenant_id=self.application_generate_entity.tenant_id,
    tool_entity = ToolManager.get_agent_tool_runtime(
        tenant_id=self.tenant_id,
        agent_tool=tool,
        agent_callback=self.agent_callback
    )
    tool_entity.load_variables(self.variables_pool)
@@ -171,33 +172,11 @@ class BaseAssistantApplicationRunner(AppRunner):
        }
    )

    runtime_parameters = {}

    parameters = tool_entity.parameters or []
    user_parameters = tool_entity.get_runtime_parameters() or []

    # override parameters
    for parameter in user_parameters:
        # check if parameter in tool parameters
        found = False
        for tool_parameter in parameters:
            if tool_parameter.name == parameter.name:
                found = True
                break

        if found:
            # override parameter
            tool_parameter.type = parameter.type
            tool_parameter.form = parameter.form
            tool_parameter.required = parameter.required
            tool_parameter.default = parameter.default
            tool_parameter.options = parameter.options
            tool_parameter.llm_description = parameter.llm_description
        else:
            # add new parameter
            parameters.append(parameter)

    parameters = tool_entity.get_all_runtime_parameters()
    for parameter in parameters:
        if parameter.form != ToolParameter.ToolParameterForm.LLM:
            continue

        parameter_type = 'string'
        enum = []
        if parameter.type == ToolParameter.ToolParameterType.STRING:
@@ -213,59 +192,16 @@ class BaseAssistantApplicationRunner(AppRunner):
        else:
            raise ValueError(f"parameter type {parameter.type} is not supported")

        if parameter.form == ToolParameter.ToolParameterForm.FORM:
            # get tool parameter from form
            tool_parameter_config = tool.tool_parameters.get(parameter.name)
            if not tool_parameter_config:
                # get default value
                tool_parameter_config = parameter.default
                if not tool_parameter_config and parameter.required:
                    raise ValueError(f"tool parameter {parameter.name} not found in tool config")

            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                # check if tool_parameter_config in options
                options = list(map(lambda x: x.value, parameter.options))
                if tool_parameter_config not in options:
                    raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")

            # convert tool parameter config to correct type
            try:
                if parameter.type == ToolParameter.ToolParameterType.NUMBER:
                    # check if tool parameter is integer
                    if isinstance(tool_parameter_config, int):
                        tool_parameter_config = tool_parameter_config
                    elif isinstance(tool_parameter_config, float):
                        tool_parameter_config = tool_parameter_config
                    elif isinstance(tool_parameter_config, str):
                        if '.' in tool_parameter_config:
                            tool_parameter_config = float(tool_parameter_config)
                        else:
                            tool_parameter_config = int(tool_parameter_config)
                elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
                    tool_parameter_config = bool(tool_parameter_config)
                elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
                    tool_parameter_config = str(tool_parameter_config)
                elif parameter.type == ToolParameter.ToolParameterType:
                    tool_parameter_config = str(tool_parameter_config)
            except Exception as e:
                raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")

            # save tool parameter to tool entity memory
            runtime_parameters[parameter.name] = tool_parameter_config

        elif parameter.form == ToolParameter.ToolParameterForm.LLM:
            message_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }
        message_tool.parameters['properties'][parameter.name] = {
            "type": parameter_type,
            "description": parameter.llm_description or '',
        }

            if len(enum) > 0:
                message_tool.parameters['properties'][parameter.name]['enum'] = enum
        if len(enum) > 0:
            message_tool.parameters['properties'][parameter.name]['enum'] = enum

            if parameter.required:
                message_tool.parameters['required'].append(parameter.name)

    tool_entity.runtime.runtime_parameters.update(runtime_parameters)
        if parameter.required:
            message_tool.parameters['required'].append(parameter.name)

    return message_tool, tool_entity
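Editor's note: the removed FORM branch above folds default lookup, option validation, and type coercion into one pass. A minimal standalone sketch of that coercion, assuming only a parameter object with the `name`, `type`, `options`, `default`, and `required` attributes used in the hunk (everything else here is illustrative, not the project's API):

```python
# Sketch of the FORM-parameter coercion removed in the hunk above.
def coerce_form_value(param, raw):
    if raw is None:
        raw = param.default
        if raw is None and param.required:
            raise ValueError(f"tool parameter {param.name} not found in tool config")

    if param.type == "select":
        # the configured value must be one of the declared options
        options = [option.value for option in param.options]
        if raw not in options:
            raise ValueError(f"value {raw} not in options {options}")
        return raw

    if param.type == "number":
        if isinstance(raw, (int, float)):
            return raw
        # strings with a decimal point become floats, otherwise ints
        return float(raw) if "." in raw else int(raw)

    if param.type == "boolean":
        return bool(raw)

    # everything else is passed through as a string
    return str(raw)
```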

@@ -305,6 +241,9 @@ class BaseAssistantApplicationRunner(AppRunner):
    tool_runtime_parameters = tool.get_runtime_parameters() or []

    for parameter in tool_runtime_parameters:
        if parameter.form != ToolParameter.ToolParameterForm.LLM:
            continue

        parameter_type = 'string'
        enum = []
        if parameter.type == ToolParameter.ToolParameterType.STRING:
@@ -320,18 +259,17 @@ class BaseAssistantApplicationRunner(AppRunner):
        else:
            raise ValueError(f"parameter type {parameter.type} is not supported")

        if parameter.form == ToolParameter.ToolParameterForm.LLM:
            prompt_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }
        prompt_tool.parameters['properties'][parameter.name] = {
            "type": parameter_type,
            "description": parameter.llm_description or '',
        }

            if len(enum) > 0:
                prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
        if len(enum) > 0:
            prompt_tool.parameters['properties'][parameter.name]['enum'] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters['required']:
                    prompt_tool.parameters['required'].append(parameter.name)
        if parameter.required:
            if parameter.name not in prompt_tool.parameters['required']:
                prompt_tool.parameters['required'].append(parameter.name)

    return prompt_tool

@@ -404,13 +342,16 @@ class BaseAssistantApplicationRunner(AppRunner):
            created_by=self.user_id,
        )
        db.session.add(message_file)
        db.session.commit()
        db.session.refresh(message_file)

        result.append((
            message_file,
            message.save_as
        ))

    db.session.commit()

    db.session.close()

    return result

def create_agent_thought(self, message_id: str, message: str,
@@ -447,6 +388,8 @@ class BaseAssistantApplicationRunner(AppRunner):

    db.session.add(thought)
    db.session.commit()
    db.session.refresh(thought)
    db.session.close()

    self.agent_thought_count += 1

@@ -464,6 +407,10 @@ class BaseAssistantApplicationRunner(AppRunner):
    """
    Save agent thought
    """
    agent_thought = db.session.query(MessageAgentThought).filter(
        MessageAgentThought.id == agent_thought.id
    ).first()

    if thought is not None:
        agent_thought.thought = thought

@@ -514,6 +461,7 @@ class BaseAssistantApplicationRunner(AppRunner):
    agent_thought.tool_labels_str = json.dumps(labels)

    db.session.commit()
    db.session.close()

def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
    """
@@ -586,9 +534,14 @@ class BaseAssistantApplicationRunner(AppRunner):
    """
    convert tool variables to db variables
    """
    db_variables = db.session.query(ToolConversationVariables).filter(
        ToolConversationVariables.conversation_id == self.message.conversation_id,
    ).first()

    db_variables.updated_at = datetime.utcnow()
    db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
    db.session.commit()
    db.session.close()

def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
    """
@@ -613,7 +566,11 @@ class BaseAssistantApplicationRunner(AppRunner):
    tools = tools.split(';')
    tool_calls: list[AssistantPromptMessage.ToolCall] = []
    tool_call_response: list[ToolPromptMessage] = []
    tool_inputs = json.loads(agent_thought.tool_input)
    try:
        tool_inputs = json.loads(agent_thought.tool_input)
    except Exception as e:
        logging.warning("tool execution error: {}, tool_input: {}.".format(str(e), agent_thought.tool_input))
        tool_inputs = { agent_thought.tool: agent_thought.tool_input }
    for tool in tools:
        # generate a uuid for tool call
        tool_call_id = str(uuid.uuid4())
@@ -644,4 +601,6 @@ class BaseAssistantApplicationRunner(AppRunner):
    if message.answer:
        result.append(AssistantPromptMessage(content=message.answer))

    return result
    db.session.close()

    return result
@@ -28,6 +28,9 @@ from models.model import Conversation, Message


class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
    _is_first_iteration = True
    _ignore_observation_providers = ['wenxin']

    def run(self, conversation: Conversation,
            message: Message,
            query: str,
@@ -42,10 +45,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
        agent_scratchpad: list[AgentScratchpadUnit] = []
        self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)

        # check model mode
        if self.app_orchestration_config.model_config.mode == "completion":
            # TODO: stop words
            if 'Observation' not in app_orchestration_config.model_config.stop:
            if 'Observation' not in app_orchestration_config.model_config.stop:
                if app_orchestration_config.model_config.provider not in self._ignore_observation_providers:
                    app_orchestration_config.model_config.stop.append('Observation')

        # override inputs
@@ -181,7 +182,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
    delta=LLMResultChunkDelta(
        index=0,
        message=AssistantPromptMessage(
            content=json.dumps(chunk)
            content=json.dumps(chunk, ensure_ascii=False)  # if ensure_ascii=True, the text in webui maybe garbled text
        ),
        usage=None
    )
@@ -202,6 +203,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
        )
    )

    scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
    agent_scratchpad.append(scratchpad)

    # get llm usage
@@ -255,9 +257,15 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
    # invoke tool
    error_response = None
    try:
        if isinstance(tool_call_args, str):
            try:
                tool_call_args = json.loads(tool_call_args)
            except json.JSONDecodeError:
                pass

        tool_response = tool_instance.invoke(
            user_id=self.user_id,
            tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
            tool_parameters=tool_call_args
        )
        # transform tool response to llm friendly response
        tool_response = self.transform_tool_invoke_messages(tool_response)
@@ -466,7 +474,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
    if isinstance(message, AssistantPromptMessage):
        current_scratchpad = AgentScratchpadUnit(
            agent_response=message.content,
            thought=message.content,
            thought=message.content or 'I am thinking about how to help you',
            action_str='',
            action=None,
            observation=None,
@@ -546,7 +554,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):

    result = ''
    for scratchpad in agent_scratchpad:
        result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"
        result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \
            next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available')

    return result

@@ -621,21 +630,24 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
    ))

    # add assistant message
    if len(agent_scratchpad) > 0:
    if len(agent_scratchpad) > 0 and not self._is_first_iteration:
        prompt_messages.append(AssistantPromptMessage(
            content=(agent_scratchpad[-1].thought or '')
            content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''),
        ))

    # add user message
    if len(agent_scratchpad) > 0:
    if len(agent_scratchpad) > 0 and not self._is_first_iteration:
        prompt_messages.append(UserPromptMessage(
            content=(agent_scratchpad[-1].observation or ''),
            content=(agent_scratchpad[-1].observation or 'It seems that no response is available'),
        ))

    self._is_first_iteration = False

    return prompt_messages
elif mode == "completion":
    # parse agent scratchpad
    agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
    self._is_first_iteration = False
    # parse prompt messages
    return [UserPromptMessage(
        content=first_prompt.replace("{{instruction}}", instruction)
@@ -655,4 +667,4 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
    try:
        return json.dumps(tools, ensure_ascii=False)
    except json.JSONDecodeError:
        return json.dumps(tools)
    return json.dumps(tools)
api/core/helper/tool_parameter_cache.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional

from extensions.ext_redis import redis_client


class ToolParameterCacheType(Enum):
    PARAMETER = "tool_parameter"

class ToolParameterCache:
    def __init__(self,
                 tenant_id: str,
                 provider: str,
                 tool_name: str,
                 cache_type: ToolParameterCacheType
                 ):
        self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"

    def get(self) -> Optional[dict]:
        """
        Get cached tool parameters.

        :return:
        """
        cached_tool_parameter = redis_client.get(self.cache_key)
        if cached_tool_parameter:
            try:
                cached_tool_parameter = cached_tool_parameter.decode('utf-8')
                cached_tool_parameter = json.loads(cached_tool_parameter)
            except JSONDecodeError:
                return None

            return cached_tool_parameter
        else:
            return None

    def set(self, parameters: dict) -> None:
        """
        Cache tool parameters.

        :param parameters: tool parameters
        :return:
        """
        redis_client.setex(self.cache_key, 86400, json.dumps(parameters))

    def delete(self) -> None:
        """
        Delete cached tool parameters.

        :return:
        """
        redis_client.delete(self.cache_key)
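Editor's note: a short usage sketch of the cache added above — parameters are stored as JSON in Redis under a tenant/provider/tool scoped key and expire after 86400 seconds. The concrete values below are illustrative only:

```python
# Store, read back, and invalidate a tool's saved parameters.
cache = ToolParameterCache(
    tenant_id="tenant-123",            # illustrative identifiers
    provider="builtin.dalle",
    tool_name="dalle3",
    cache_type=ToolParameterCacheType.PARAMETER,
)
cache.set({"size": "1024x1024", "n": 1})
params = cache.get()   # -> {'size': '1024x1024', 'n': 1}, or None on a miss
cache.delete()
```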
@@ -82,6 +82,8 @@ class HostingConfiguration:
        RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM),
        RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM),
        RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING),
        RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING),
        RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING),
    ]
)
quotas.append(trial_quota)

@@ -1,3 +1,4 @@
import concurrent.futures
import datetime
import json
import logging
@@ -62,7 +63,8 @@ class IndexingRunner:
    text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

    # transform
    documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
    documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
                                processing_rule.to_dict())
    # save segment
    self._load_segments(dataset, dataset_document, documents)

@@ -120,7 +122,8 @@ class IndexingRunner:
    text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

    # transform
    documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
    documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
                                processing_rule.to_dict())
    # save segment
    self._load_segments(dataset, dataset_document, documents)

@@ -186,7 +189,7 @@ class IndexingRunner:
    first()

    index_type = dataset_document.doc_form
    index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor()
    index_processor = IndexProcessorFactory(index_type).init_index_processor()
    self._load(
        index_processor=index_processor,
        dataset=dataset,
@@ -414,9 +417,14 @@ class IndexingRunner:
    if separator:
        separator = separator.replace('\\n', '\n')

    if 'chunk_overlap' in segmentation and segmentation['chunk_overlap']:
        chunk_overlap = segmentation['chunk_overlap']
    else:
        chunk_overlap = 0

    character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
        chunk_size=segmentation["max_tokens"],
        chunk_overlap=segmentation.get('chunk_overlap', 0),
        chunk_overlap=chunk_overlap,
        fixed_separator=separator,
        separators=["\n\n", "。", ".", " ", ""],
        embedding_model_instance=embedding_model_instance
@@ -643,17 +651,44 @@ class IndexingRunner:
    # chunk nodes by chunk size
    indexing_start_at = time.perf_counter()
    tokens = 0
    chunk_size = 100
    chunk_size = 10

    embedding_model_type_instance = None
    if embedding_model_instance:
        embedding_model_type_instance = embedding_model_instance.model_type_instance
        embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        futures = []
        for i in range(0, len(documents), chunk_size):
            chunk_documents = documents[i:i + chunk_size]
            futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
                                           chunk_documents, dataset,
                                           dataset_document, embedding_model_instance,
                                           embedding_model_type_instance))

    for i in range(0, len(documents), chunk_size):
        for future in futures:
            tokens += future.result()

    indexing_end_at = time.perf_counter()

    # update document status to completed
    self._update_document_index_status(
        document_id=dataset_document.id,
        after_indexing_status="completed",
        extra_update_params={
            DatasetDocument.tokens: tokens,
            DatasetDocument.completed_at: datetime.datetime.utcnow(),
            DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
        }
    )

def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
                   embedding_model_instance, embedding_model_type_instance):
    with flask_app.app_context():
        # check document is paused
        self._check_document_paused_status(dataset_document.id)
        chunk_documents = documents[i:i + chunk_size]

        tokens = 0
        if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
            tokens += sum(
                embedding_model_type_instance.get_num_tokens(
@@ -663,9 +698,9 @@ class IndexingRunner:
                )
                for document in chunk_documents
            )

        # load index
        index_processor.load(dataset, chunk_documents)
        db.session.add(dataset)

        document_ids = [document.metadata['doc_id'] for document in chunk_documents]
        db.session.query(DocumentSegment).filter(
@@ -680,18 +715,7 @@ class IndexingRunner:

        db.session.commit()

        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.utcnow(),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )
        return tokens

def _check_document_paused_status(self, document_id: str):
    indexing_cache_key = 'document_{}_is_paused'.format(document_id)
@@ -750,7 +774,7 @@ class IndexingRunner:
    index_processor.load(dataset, documents)

def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
               text_docs: list[Document], process_rule: dict) -> list[Document]:
               text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]:
    # get embedding model instance
    embedding_model_instance = None
    if dataset.indexing_technique == 'high_quality':
@@ -768,7 +792,8 @@ class IndexingRunner:
    )

    documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
                                          process_rule=process_rule)
                                          process_rule=process_rule, tenant_id=dataset.tenant_id,
                                          doc_language=doc_language)

    return documents

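Editor's note: the restructuring above replaces a sequential indexing loop with a chunked thread pool — split the documents into fixed-size chunks, submit each chunk, then sum the per-chunk token counts from the futures. A minimal, self-contained sketch of that pattern (the placeholder body stands in for `_process_chunk`):

```python
import concurrent.futures

def process_in_chunks(items, chunk_size=10, max_workers=10):
    def process_chunk(chunk):
        # placeholder for the real per-chunk work (token counting + load)
        return len(chunk)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # fan out: one future per fixed-size chunk
        futures = [
            executor.submit(process_chunk, items[i:i + chunk_size])
            for i in range(0, len(items), chunk_size)
        ]
        # aggregate: total "tokens" across all chunks
        return sum(future.result() for future in futures)
```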
@@ -47,11 +47,14 @@ class TokenBufferMemory:
        files, message.app_model_config
    )

    prompt_message_contents = [TextPromptMessageContent(data=message.query)]
    for file_obj in file_objs:
        prompt_message_contents.append(file_obj.prompt_message_content)
    if not file_objs:
        prompt_messages.append(UserPromptMessage(content=message.query))
    else:
        prompt_message_contents = [TextPromptMessageContent(data=message.query)]
        for file_obj in file_objs:
            prompt_message_contents.append(file_obj.prompt_message_content)

    prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
    prompt_messages.append(UserPromptMessage(content=message.query))

@@ -147,7 +147,7 @@

- `input` (float) Input price, i.e., Prompt price
- `output` (float) Output price, i.e., returned content price
- `unit` (float) Pricing unit, e.g., per 100K price is `0.000001`
- `unit` (float) Pricing unit, e.g., if the price is measured in 1M tokens, the corresponding token amount for the unit price is `0.000001`.
- `currency` (string) Currency unit

### ProviderCredentialSchema

@@ -149,7 +149,7 @@

- `input` (float) 输入单价,即 Prompt 单价
- `output` (float) 输出单价,即返回内容单价
- `unit` (float) 价格单位,如:每 100K 的单价为 `0.000001`
- `unit` (float) 价格单位,如以 1M tokens 计价,则单价对应的单位 token 数为 `0.000001`
- `currency` (string) 货币单位

### ProviderCredentialSchema

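Editor's note: a worked example of the corrected `unit` description, using the claude-3-haiku pricing that appears later in this diff (input `0.25`, unit `0.000001`):

```python
# price = token_count * unit * per-unit price
tokens = 1_000_000
unit = 0.000001        # per-token fraction of the 1M-token unit price
input_price = 0.25     # USD per 1M input tokens (claude-3-haiku below)
cost = tokens * unit * input_price   # -> 0.25 USD
```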
@@ -73,8 +73,8 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
    },
    'type': 'int',
    'help': {
        'en_US': 'The maximum number of tokens to generate. Requests can use up to 2048 tokens shared between prompt and completion.',
        'zh_Hans': '要生成的标记的最大数量。请求可以使用最多2048个标记,这些标记在提示和完成之间共享。',
        'en_US': 'Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.',
        'zh_Hans': '指定生成结果长度的上限。如果生成结果截断,可以调大该参数。',
    },
    'required': False,
    'default': 64,

@@ -17,7 +17,7 @@ class ModelType(Enum):
    SPEECH2TEXT = "speech2text"
    MODERATION = "moderation"
    TTS = "tts"
    # TEXT2IMG = "text2img"
    TEXT2IMG = "text2img"

    @classmethod
    def value_of(cls, origin_model_type: str) -> "ModelType":
@@ -36,6 +36,8 @@ class ModelType(Enum):
        return cls.SPEECH2TEXT
    elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
        return cls.TTS
    elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value:
        return cls.TEXT2IMG
    elif origin_model_type == cls.MODERATION.value:
        return cls.MODERATION
    else:
@@ -59,10 +61,11 @@ class ModelType(Enum):
        return 'tts'
    elif self == self.MODERATION:
        return 'moderation'
    elif self == self.TEXT2IMG:
        return 'text2img'
    else:
        raise ValueError(f'invalid model type {self}')


class FetchFrom(Enum):
    """
    Enum class for fetch from.
@@ -130,7 +133,7 @@ class ModelPropertyKey(Enum):
    DEFAULT_VOICE = "default_voice"
    VOICES = "voices"
    WORD_LIMIT = "word_limit"
    AUDOI_TYPE = "audio_type"
    AUDIO_TYPE = "audio_type"
    MAX_WORKERS = "max_workers"

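Editor's note: the two hunks above make the new member round-trip through both conversions. A sketch of that round trip — assuming the reverse-mapping method shown in the second hunk (whose name falls outside this excerpt) is `to_origin_model_type`:

```python
# Round-trip the new TEXT2IMG member through both conversions.
assert ModelType.value_of("text2img") is ModelType.TEXT2IMG
assert ModelType.TEXT2IMG.to_origin_model_type() == "text2img"  # assumed method name
```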
@@ -18,6 +18,7 @@ from core.model_runtime.entities.model_entities import (
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.utils.position_helper import get_position_map, sort_by_position_map


class AIModel(ABC):
@@ -148,15 +149,7 @@ class AIModel(ABC):
    ]

    # get _position.yaml file path
    position_file_path = os.path.join(provider_model_type_path, '_position.yaml')

    # read _position.yaml file
    position_map = {}
    if os.path.exists(position_file_path):
        with open(position_file_path, encoding='utf-8') as f:
            positions = yaml.safe_load(f)
            # convert list to dict with key as model provider name, value as index
            position_map = {position: index for index, position in enumerate(positions)}
    position_map = get_position_map(provider_model_type_path)

    # traverse all model_schema_yaml_paths
    for model_schema_yaml_path in model_schema_yaml_paths:
@@ -206,8 +199,7 @@ class AIModel(ABC):
    model_schemas.append(model_schema)

    # resort model schemas by position
    if position_map:
        model_schemas.sort(key=lambda x: position_map.get(x.model, 999))
    model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)

    # cache model schemas
    self.model_schemas = model_schemas

@@ -1,4 +1,3 @@
import importlib
import os
from abc import ABC, abstractmethod

@@ -7,6 +6,7 @@ import yaml
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source


class ModelProvider(ABC):
@@ -104,17 +104,10 @@ class ModelProvider(ABC):

    # Dynamic loading {model_type_name}.py file and find the subclass of AIModel
    parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
    spec = importlib.util.spec_from_file_location(f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)

    model_class = None
    for name, obj in vars(mod).items():
        if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
                and obj != AIModel and obj.__module__ == mod.__name__):
            model_class = obj
            break

    mod = import_module_from_source(
        f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
    model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
                              get_subclasses_from_module(mod, AIModel)), None)
    if not model_class:
        raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')

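Editor's note: the refactor above moves the inline importlib walk into shared helpers. Plausible shapes for those two helpers, inferred from the code they replace — the real implementations live in core.utils.module_import_helper and may differ:

```python
import importlib.util

def import_module_from_source(module_name, py_file_path):
    # load a module object from an arbitrary .py file path
    spec = importlib.util.spec_from_file_location(module_name, py_file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

def get_subclasses_from_module(module, parent_class):
    # collect classes defined (or visible) in the module that subclass parent_class
    return [
        obj for obj in vars(module).values()
        if isinstance(obj, type) and issubclass(obj, parent_class) and obj is not parent_class
    ]
```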
@@ -0,0 +1,48 @@
from abc import abstractmethod
from typing import IO, Optional

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel


class Text2ImageModel(AIModel):
    """
    Model class for text2img model.
    """
    model_type: ModelType = ModelType.TEXT2IMG

    def invoke(self, model: str, credentials: dict, prompt: str,
               model_parameters: dict, user: Optional[str] = None) \
            -> list[IO[bytes]]:
        """
        Invoke Text2Image model

        :param model: model name
        :param credentials: model credentials
        :param prompt: prompt for image generation
        :param model_parameters: model parameters
        :param user: unique user id

        :return: image bytes
        """
        try:
            return self._invoke(model, credentials, prompt, model_parameters, user)
        except Exception as e:
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(self, model: str, credentials: dict, prompt: str,
                model_parameters: dict, user: Optional[str] = None) \
            -> list[IO[bytes]]:
        """
        Invoke Text2Image model

        :param model: model name
        :param credentials: model credentials
        :param prompt: prompt for image generation
        :param model_parameters: model parameters
        :param user: unique user id

        :return: image bytes
        """
        raise NotImplementedError
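Editor's note: a hypothetical provider implementation of the abstract base above, to show the intended subclassing contract — the class name and the placeholder image bytes are illustrative, not a real provider:

```python
from io import BytesIO
from typing import IO, Optional

class ExampleText2ImageModel(Text2ImageModel):
    def _invoke(self, model: str, credentials: dict, prompt: str,
                model_parameters: dict, user: Optional[str] = None) -> list[IO[bytes]]:
        # a real implementation would call the provider's image API here
        fake_png = b"\x89PNG..."  # placeholder bytes, not a valid image
        return [BytesIO(fake_png)]
```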
@@ -94,8 +94,8 @@ class TTSModel(AIModel):
    """
    model_schema = self.get_model_schema(model, credentials)

    if model_schema and ModelPropertyKey.AUDOI_TYPE in model_schema.model_properties:
        return model_schema.model_properties[ModelPropertyKey.AUDOI_TYPE]
    if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties:
        return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]

def _get_model_word_limit(self, model: str, credentials: dict) -> int:
    """

@@ -2,13 +2,17 @@
- anthropic
- azure_openai
- google
- nvidia
- cohere
- bedrock
- togetherai
- ollama
- mistralai
- groq
- replicate
- huggingface_hub
- xinference
- triton_inference_server
- zhipuai
- baichuan
- spark
@@ -18,7 +22,7 @@
- moonshot
- jina
- chatglm
- xinference
- yi
- openllm
- localai
- openai_api_compatible

@@ -0,0 +1,37 @@
model: claude-3-haiku-20240307
label:
  en_US: claude-3-haiku-20240307
model_type: llm
features:
  - agent-thought
  - vision
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '0.25'
  output: '1.25'
  unit: '0.000001'
  currency: USD
@@ -342,12 +342,20 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
    Convert prompt messages to dict list and system
    """
    system = ""
    prompt_message_dicts = []

    first_loop = True
    for message in prompt_messages:
        if isinstance(message, SystemPromptMessage):
            system += message.content + ("\n" if not system else "")
        else:
            message.content = message.content.strip()
            if first_loop:
                system = message.content
                first_loop = False
            else:
                system += "\n"
                system += message.content

    prompt_message_dicts = []
    for message in prompt_messages:
        if not isinstance(message, SystemPromptMessage):
            prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))

    return system, prompt_message_dicts
@@ -424,8 +432,25 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):

    if isinstance(message, UserPromptMessage):
        message_text = f"{human_prompt} {content}"
        if not isinstance(message.content, list):
            message_text = f"{ai_prompt} {content}"
        else:
            message_text = ""
            for sub_message in message.content:
                if sub_message.type == PromptMessageContentType.TEXT:
                    message_text += f"{human_prompt} {sub_message.data}"
                elif sub_message.type == PromptMessageContentType.IMAGE:
                    message_text += f"{human_prompt} [IMAGE]"
    elif isinstance(message, AssistantPromptMessage):
        message_text = f"{ai_prompt} {content}"
        if not isinstance(message.content, list):
            message_text = f"{ai_prompt} {content}"
        else:
            message_text = ""
            for sub_message in message.content:
                if sub_message.type == PromptMessageContentType.TEXT:
                    message_text += f"{ai_prompt} {sub_message.data}"
                elif sub_message.type == PromptMessageContentType.IMAGE:
                    message_text += f"{ai_prompt} [IMAGE]"
    elif isinstance(message, SystemPromptMessage):
        message_text = content
    else:

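Editor's note: reading the first hunk above (the +/- markers were lost in this view), the rewritten merge appears to strip each system message's content and join successive system contents with newlines. A standalone rendering of that accumulation logic, with plain strings standing in for `SystemPromptMessage.content`:

```python
# Newline-join stripped system contents, seeding from the first one.
messages = ["You are helpful.  ", "Answer in English."]  # illustrative contents
system = ""
first_loop = True
for content in messages:
    content = content.strip()
    if first_loop:
        system = content
        first_loop = False
    else:
        system += "\n" + content
assert system == "You are helpful.\nAnswer in English."
```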
@@ -15,10 +15,11 @@ from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPEN
class _CommonAzureOpenAI:
    @staticmethod
    def _to_credential_kwargs(credentials: dict) -> dict:
        api_version = credentials.get('openai_api_version', AZURE_OPENAI_API_VERSION)
        credentials_kwargs = {
            "api_key": credentials['openai_api_key'],
            "azure_endpoint": credentials['openai_api_base'],
            "api_version": AZURE_OPENAI_API_VERSION,
            "api_version": api_version,
            "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
            "max_retries": 1,
        }

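Editor's note: with the change above, an explicit per-credential API version now wins and the module constant is only a fallback. An illustrative call (key and endpoint values are placeholders):

```python
credentials = {
    "openai_api_key": "sk-...",                        # placeholder
    "openai_api_base": "https://example.openai.azure.com",
    "openai_api_version": "2023-12-01-preview",        # optional override
}
kwargs = _CommonAzureOpenAI._to_credential_kwargs(credentials)
assert kwargs["api_version"] == "2023-12-01-preview"
# without the "openai_api_version" key, api_version falls back to AZURE_OPENAI_API_VERSION
```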
@@ -14,8 +14,7 @@ from core.model_runtime.entities.model_entities import (
    PriceConfig,
)

AZURE_OPENAI_API_VERSION = '2023-12-01-preview'

AZURE_OPENAI_API_VERSION = '2024-02-15-preview'

def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule:
    rule = ParameterRule(
@@ -124,6 +123,65 @@ LLM_BASE_MODELS = [
            )
        )
    ),
    AzureBaseModel(
        base_model_name='gpt-35-turbo-0125',
        entity=AIModelEntity(
            model='fake-deployment-name',
            label=I18nObject(
                en_US='fake-deployment-name-label',
            ),
            model_type=ModelType.LLM,
            features=[
                ModelFeature.AGENT_THOUGHT,
                ModelFeature.MULTI_TOOL_CALL,
                ModelFeature.STREAM_TOOL_CALL,
            ],
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.MODE: LLMMode.CHAT.value,
                ModelPropertyKey.CONTEXT_SIZE: 16385,
            },
            parameter_rules=[
                ParameterRule(
                    name='temperature',
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
                ),
                ParameterRule(
                    name='top_p',
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
                ),
                ParameterRule(
                    name='presence_penalty',
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
                ),
                ParameterRule(
                    name='frequency_penalty',
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
                ),
                _get_max_tokens(default=512, min_val=1, max_val=4096),
                ParameterRule(
                    name='response_format',
                    label=I18nObject(
                        zh_Hans='回复格式',
                        en_US='response_format'
                    ),
                    type='string',
                    help=I18nObject(
                        zh_Hans='指定模型必须输出的格式',
                        en_US='specifying the format that the model must output'
                    ),
                    required=False,
                    options=['text', 'json_object']
                ),
            ],
            pricing=PriceConfig(
                input=0.0005,
                output=0.0015,
                unit=0.001,
                currency='USD',
            )
        )
    ),
    AzureBaseModel(
        base_model_name='gpt-4',
        entity=AIModelEntity(
@@ -274,6 +332,81 @@ LLM_BASE_MODELS = [
            )
        )
    ),
    AzureBaseModel(
        base_model_name='gpt-4-0125-preview',
        entity=AIModelEntity(
            model='fake-deployment-name',
            label=I18nObject(
                en_US='fake-deployment-name-label',
            ),
            model_type=ModelType.LLM,
            features=[
                ModelFeature.AGENT_THOUGHT,
                ModelFeature.MULTI_TOOL_CALL,
                ModelFeature.STREAM_TOOL_CALL,
            ],
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.MODE: LLMMode.CHAT.value,
                ModelPropertyKey.CONTEXT_SIZE: 128000,
            },
            parameter_rules=[
                ParameterRule(
                    name='temperature',
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
                ),
                ParameterRule(
                    name='top_p',
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
                ),
                ParameterRule(
                    name='presence_penalty',
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
                ),
                ParameterRule(
                    name='frequency_penalty',
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
                ),
                _get_max_tokens(default=512, min_val=1, max_val=4096),
                ParameterRule(
                    name='seed',
                    label=I18nObject(
                        zh_Hans='种子',
                        en_US='Seed'
                    ),
                    type='int',
                    help=I18nObject(
                        zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
                        en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
                    ),
                    required=False,
                    precision=2,
                    min=0,
                    max=1,
                ),
                ParameterRule(
                    name='response_format',
                    label=I18nObject(
                        zh_Hans='回复格式',
                        en_US='response_format'
                    ),
                    type='string',
                    help=I18nObject(
                        zh_Hans='指定模型必须输出的格式',
                        en_US='specifying the format that the model must output'
                    ),
                    required=False,
                    options=['text', 'json_object']
                ),
            ],
            pricing=PriceConfig(
                input=0.01,
                output=0.03,
                unit=0.001,
                currency='USD',
            )
        )
    ),
    AzureBaseModel(
        base_model_name='gpt-4-1106-preview',
        entity=AIModelEntity(
@@ -524,5 +657,172 @@ EMBEDDING_BASE_MODELS = [
                currency='USD',
            )
        )
    ),
    AzureBaseModel(
        base_model_name='text-embedding-3-small',
        entity=AIModelEntity(
            model='fake-deployment-name',
            label=I18nObject(
                en_US='fake-deployment-name-label'
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: 8191,
                ModelPropertyKey.MAX_CHUNKS: 32,
            },
            pricing=PriceConfig(
                input=0.00002,
                unit=0.001,
                currency='USD',
            )
        )
    ),
    AzureBaseModel(
        base_model_name='text-embedding-3-large',
        entity=AIModelEntity(
            model='fake-deployment-name',
            label=I18nObject(
                en_US='fake-deployment-name-label'
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: 8191,
                ModelPropertyKey.MAX_CHUNKS: 32,
            },
            pricing=PriceConfig(
                input=0.00013,
                unit=0.001,
                currency='USD',
            )
        )
    )
]
SPEECH2TEXT_BASE_MODELS = [
    AzureBaseModel(
        base_model_name='whisper-1',
        entity=AIModelEntity(
            model='fake-deployment-name',
            label=I18nObject(
                en_US='fake-deployment-name-label'
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.SPEECH2TEXT,
            model_properties={
                ModelPropertyKey.FILE_UPLOAD_LIMIT: 25,
                ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm'
            }
        )
    )
]
TTS_BASE_MODELS = [
    AzureBaseModel(
        base_model_name='tts-1',
        entity=AIModelEntity(
            model='fake-deployment-name',
            label=I18nObject(
                en_US='fake-deployment-name-label'
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={
                ModelPropertyKey.DEFAULT_VOICE: 'alloy',
                ModelPropertyKey.VOICES: [
                    {
                        'mode': 'alloy',
                        'name': 'Alloy',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'echo',
                        'name': 'Echo',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'fable',
                        'name': 'Fable',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'onyx',
                        'name': 'Onyx',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'nova',
                        'name': 'Nova',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'shimmer',
                        'name': 'Shimmer',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                ],
                ModelPropertyKey.WORD_LIMIT: 120,
                ModelPropertyKey.AUDIO_TYPE: 'mp3',
                ModelPropertyKey.MAX_WORKERS: 5
            },
            pricing=PriceConfig(
                input=0.015,
                unit=0.001,
                currency='USD',
            )
        )
    ),
    AzureBaseModel(
        base_model_name='tts-1-hd',
        entity=AIModelEntity(
            model='fake-deployment-name',
            label=I18nObject(
                en_US='fake-deployment-name-label'
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={
                ModelPropertyKey.DEFAULT_VOICE: 'alloy',
                ModelPropertyKey.VOICES: [
                    {
                        'mode': 'alloy',
                        'name': 'Alloy',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'echo',
                        'name': 'Echo',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'fable',
                        'name': 'Fable',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'onyx',
                        'name': 'Onyx',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'nova',
                        'name': 'Nova',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'shimmer',
                        'name': 'Shimmer',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                ],
                ModelPropertyKey.WORD_LIMIT: 120,
                ModelPropertyKey.AUDIO_TYPE: 'mp3',
                ModelPropertyKey.MAX_WORKERS: 5
            },
            pricing=PriceConfig(
                input=0.03,
                unit=0.001,
                currency='USD',
            )
        )
    )
]

@@ -15,6 +15,8 @@ help:
supported_model_types:
  - llm
  - text-embedding
  - speech2text
  - tts
configurate_methods:
  - customizable-model
model_credential_schema:
@@ -44,6 +46,22 @@ model_credential_schema:
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API key here
    - variable: openai_api_version
      label:
        zh_Hans: API 版本
        en_US: API Version
      type: select
      required: true
      options:
        - label:
            en_US: 2024-02-15-preview
          value: 2024-02-15-preview
        - label:
            en_US: 2023-12-01-preview
          value: 2023-12-01-preview
      placeholder:
        zh_Hans: 在此选择您的 API 版本
        en_US: Select your API Version here
    - variable: base_model_name
      label:
        en_US: Base Model
@@ -57,6 +75,12 @@ model_credential_schema:
          show_on:
            - variable: __model_type
              value: llm
        - label:
            en_US: gpt-35-turbo-0125
          value: gpt-35-turbo-0125
          show_on:
            - variable: __model_type
              value: llm
        - label:
            en_US: gpt-35-turbo-16k
          value: gpt-35-turbo-16k
@@ -75,6 +99,12 @@ model_credential_schema:
          show_on:
            - variable: __model_type
              value: llm
        - label:
            en_US: gpt-4-0125-preview
          value: gpt-4-0125-preview
          show_on:
            - variable: __model_type
              value: llm
        - label:
            en_US: gpt-4-1106-preview
          value: gpt-4-1106-preview
@@ -99,6 +129,36 @@ model_credential_schema:
          show_on:
            - variable: __model_type
              value: text-embedding
        - label:
            en_US: text-embedding-3-small
          value: text-embedding-3-small
          show_on:
            - variable: __model_type
              value: text-embedding
        - label:
            en_US: text-embedding-3-large
          value: text-embedding-3-large
          show_on:
            - variable: __model_type
              value: text-embedding
        - label:
            en_US: whisper-1
          value: whisper-1
          show_on:
            - variable: __model_type
              value: speech2text
        - label:
            en_US: tts-1
          value: tts-1
          show_on:
            - variable: __model_type
              value: tts
        - label:
            en_US: tts-1-hd
          value: tts-1-hd
          show_on:
            - variable: __model_type
              value: tts
      placeholder:
        zh_Hans: 在此输入您的模型版本
        en_US: Enter your model version

@@ -0,0 +1,82 @@
import copy
from typing import IO, Optional

from openai import AzureOpenAI

from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel


class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
    """
    Model class for OpenAI Speech to text model.
    """

    def _invoke(self, model: str, credentials: dict,
                file: IO[bytes], user: Optional[str] = None) \
            -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        return self._speech2text_invoke(model, credentials, file)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            audio_file_path = self._get_demo_file_path()

            with open(audio_file_path, 'rb') as audio_file:
                self._speech2text_invoke(model, credentials, audio_file)
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :return: text for given audio file
        """
        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)

        # init model client
        client = AzureOpenAI(**credentials_kwargs)

        response = client.audio.transcriptions.create(model=model, file=file)

        return response.text

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
        return ai_model_entity.entity

    @staticmethod
    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
        for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
            if ai_model_entity.base_model_name == base_model_name:
                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
                ai_model_entity_copy.entity.model = model
                ai_model_entity_copy.entity.label.en_US = model
                ai_model_entity_copy.entity.label.zh_Hans = model
                return ai_model_entity_copy

        return None
api/core/model_runtime/model_providers/azure_openai/tts/tts.py (new file, 174 lines)
@@ -0,0 +1,174 @@
import concurrent.futures
import copy
from functools import reduce
from io import BytesIO
from typing import Optional

from flask import Response, stream_with_context
from openai import AzureOpenAI
from pydub import AudioSegment

from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
from extensions.ext_storage import storage


class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
    """
    Model class for OpenAI Speech to text model.
    """

    def _invoke(self, model: str, tenant_id: str, credentials: dict,
                content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
        """
        _invoke text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :param streaming: output is streaming
        :param user: unique user id
        :return: text translated to audio file
        """
        audio_type = self._get_model_audio_type(model, credentials)
        if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
            voice = self._get_model_default_voice(model, credentials)
        if streaming:
            return Response(stream_with_context(self._tts_invoke_streaming(model=model,
                                                                           credentials=credentials,
                                                                           content_text=content_text,
                                                                           tenant_id=tenant_id,
                                                                           voice=voice)),
                            status=200, mimetype=f'audio/{audio_type}')
        else:
            return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)

    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
        """
        validate credentials text2speech model

        :param model: model name
        :param credentials: model credentials
        :param user: unique user id
        :return: text translated to audio file
        """
        try:
            self._tts_invoke(
                model=model,
                credentials=credentials,
                content_text='Hello Dify!',
                voice=self._get_model_default_voice(model, credentials),
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response:
        """
        _tts_invoke text2speech model

        :param model: model name
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :return: text translated to audio file
        """
        audio_type = self._get_model_audio_type(model, credentials)
        word_limit = self._get_model_word_limit(model, credentials)
        max_workers = self._get_model_workers_limit(model, credentials)
        try:
            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
            audio_bytes_list = list()

            # Create a thread pool and map the function to the list of sentences
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice,
                                           credentials=credentials) for sentence in sentences]
                for future in futures:
                    try:
                        if future.result():
                            audio_bytes_list.append(future.result())
                    except Exception as ex:
                        raise InvokeBadRequestError(str(ex))

            if len(audio_bytes_list) > 0:
                audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in
                                  audio_bytes_list if audio_bytes]
                combined_segment = reduce(lambda x, y: x + y, audio_segments)
                buffer: BytesIO = BytesIO()
                combined_segment.export(buffer, format=audio_type)
                buffer.seek(0)
                return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

    # Todo: To improve the streaming function
    def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
                              voice: str) -> any:
        """
        _tts_invoke_streaming text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :return: text translated to audio file
        """
        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)
        if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
            voice = self._get_model_default_voice(model, credentials)
        word_limit = self._get_model_word_limit(model, credentials)
        audio_type = self._get_model_audio_type(model, credentials)
        tts_file_id = self._get_file_name(content_text)
        file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
        try:
            client = AzureOpenAI(**credentials_kwargs)
            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
            for sentence in sentences:
                response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
                # response.stream_to_file(file_path)
                storage.save(file_path, response.read())
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

    def _process_sentence(self, sentence: str, model: str,
                          voice, credentials: dict):
        """
        _tts_invoke openai text2speech model api

        :param model: model name
        :param credentials: model credentials
        :param voice: model timbre
        :param sentence: text content to be translated
        :return: text translated to audio file
        """
        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)
        client = AzureOpenAI(**credentials_kwargs)
        response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
        if isinstance(response.read(), bytes):
            return response.read()

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
        return ai_model_entity.entity

    @staticmethod
    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
        for ai_model_entity in TTS_BASE_MODELS:
            if ai_model_entity.base_model_name == base_model_name:
                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
                ai_model_entity_copy.entity.model = model
                ai_model_entity_copy.entity.label.en_US = model
                ai_model_entity_copy.entity.label.zh_Hans = model
                return ai_model_entity_copy

        return None
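Editor's note: the merge step in `_tts_invoke` above decodes each synthesized chunk with pydub, concatenates the segments, and re-exports one audio stream. A minimal standalone sketch of that step:

```python
from functools import reduce
from io import BytesIO

from pydub import AudioSegment

def combine_audio(audio_bytes_list: list[bytes], audio_type: str = "mp3") -> bytes:
    # decode every non-empty chunk into an AudioSegment
    segments = [AudioSegment.from_file(BytesIO(b), format=audio_type)
                for b in audio_bytes_list if b]
    # "+" on AudioSegment concatenates audio end to end
    combined = reduce(lambda x, y: x + y, segments)
    buffer = BytesIO()
    combined.export(buffer, format=audio_type)
    return buffer.getvalue()
```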
@@ -108,7 +108,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
    try:
        response = post(url, headers=headers, data=dumps(data))
    except Exception as e:
        raise InvokeConnectionError(e)
        raise InvokeConnectionError(str(e))

    if response.status_code != 200:
        try:
@@ -124,7 +124,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
    elif err == 'insufficient_quota':
        raise InsufficientAccountBalance(msg)
    elif err == 'invalid_authentication':
        raise InvalidAuthenticationError(msg)
        raise InvalidAuthenticationError(msg)
    elif err and 'rate' in err:
        raise RateLimitReachedError(msg)
    elif err and 'internal' in err:

@@ -18,9 +18,10 @@ class BedrockProvider(ModelProvider):
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

-           # Use `gemini-pro` model for validate,
+           # Use `amazon.titan-text-lite-v1` model by default for validating credentials
+           model_for_validation = credentials.get('model_for_validation', 'amazon.titan-text-lite-v1')
            model_instance.validate_credentials(
-               model='amazon.titan-text-lite-v1',
+               model=model_for_validation,
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
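The new lookup makes the validation model overridable: accounts without access to the default Titan model can supply any model they are entitled to invoke. Roughly, the credentials payload would look like this (all values are placeholders):

# Illustrative credentials dict; every value below is a placeholder.
credentials = {
    'aws_access_key_id': 'AKIA...',
    'aws_secret_access_key': '***',
    'aws_region': 'us-east-1',
    'model_for_validation': 'anthropic.claude-instant-v1',  # optional override
}
model = credentials.get('model_for_validation', 'amazon.titan-text-lite-v1')
# -> 'anthropic.claude-instant-v1' here; the Titan default applies when the key is absent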
@@ -48,24 +48,33 @@ provider_credential_schema:
      - value: us-east-1
        label:
          en_US: US East (N. Virginia)
-         zh_Hans: US East (N. Virginia)
+         zh_Hans: 美国东部 (弗吉尼亚北部)
      - value: us-west-2
        label:
          en_US: US West (Oregon)
-         zh_Hans: US West (Oregon)
+         zh_Hans: 美国西部 (俄勒冈州)
      - value: ap-southeast-1
        label:
          en_US: Asia Pacific (Singapore)
-         zh_Hans: Asia Pacific (Singapore)
+         zh_Hans: 亚太地区 (新加坡)
      - value: ap-northeast-1
        label:
          en_US: Asia Pacific (Tokyo)
-         zh_Hans: Asia Pacific (Tokyo)
+         zh_Hans: 亚太地区 (东京)
      - value: eu-central-1
        label:
          en_US: Europe (Frankfurt)
-         zh_Hans: Europe (Frankfurt)
+         zh_Hans: 欧洲 (法兰克福)
      - value: us-gov-west-1
        label:
          en_US: AWS GovCloud (US-West)
          zh_Hans: AWS GovCloud (US-West)
+   - variable: model_for_validation
+     required: false
+     label:
+       en_US: Available Model Name
+       zh_Hans: 可用模型名称
+     type: secret-input
+     placeholder:
+       en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation.
+       zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如:amazon.titan-text-lite-v1)
@@ -4,6 +4,8 @@
 - anthropic.claude-v1
 - anthropic.claude-v2
 - anthropic.claude-v2:1
+- anthropic.claude-3-sonnet-v1:0
+- anthropic.claude-3-haiku-v1:0
 - cohere.command-light-text-v14
 - cohere.command-text-v14
 - meta.llama2-13b-chat-v1
@@ -0,0 +1,57 @@
model: anthropic.claude-3-haiku-20240307-v1:0
label:
  en_US: Claude 3 Haiku
model_type: llm
features:
  - agent-thought
  - vision
model_properties:
  mode: chat
  context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
  # docs: https://docs.anthropic.com/claude/docs/system-prompts
  - name: temperature
    use_template: temperature
    required: false
    type: float
    default: 1
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 生成内容的随机性。
      en_US: The amount of randomness injected into the response.
  - name: top_p
    required: false
    type: float
    default: 0.999
    min: 0.000
    max: 1.000
    help:
      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
  - name: top_k
    required: false
    type: int
    default: 0
    min: 0
    # tip: the AWS docs are wrong here; the max value is 500
    max: 500
    help:
      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
  input: '0.003'
  output: '0.015'
  unit: '0.001'
  currency: USD
@@ -0,0 +1,56 @@
model: anthropic.claude-3-sonnet-20240229-v1:0
label:
  en_US: Claude 3 Sonnet
model_type: llm
features:
  - agent-thought
  - vision
model_properties:
  mode: chat
  context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
  - name: temperature
    use_template: temperature
    required: false
    type: float
    default: 1
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 生成内容的随机性。
      en_US: The amount of randomness injected into the response.
  - name: top_p
    required: false
    type: float
    default: 0.999
    min: 0.000
    max: 1.000
    help:
      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
  - name: top_k
    required: false
    type: int
    default: 0
    min: 0
    # tip: the AWS docs are wrong here; the max value is 500
    max: 500
    help:
      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
  input: '0.00025'
  output: '0.00125'
  unit: '0.001'
  currency: USD
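For orientation, the pricing block is scaled by `unit`. Assuming the runtime computes cost as tokens × price × unit (consistent with the `prompt_unit_price` / `prompt_price_unit` fields in the usage objects later in this diff), the file above works out as follows:

# Worked example with the values from the YAML above (the cost formula is
# an assumption about the runtime, not code from this commit).
tokens = 1_000_000   # one million input tokens
price = 0.00025      # 'input' price in USD
unit = 0.001         # i.e. prices are quoted per 1,000 tokens
print(tokens * price * unit)  # 0.25 -> $0.25 per million input tokens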
@@ -1,33 +1,50 @@
model: anthropic.claude-instant-v1
label:
-  en_US: Claude Instant V1
+  en_US: Claude Instant 1
model_type: llm
model_properties:
  mode: chat
  context_size: 100000
parameter_rules:
-  - name: temperature
-    use_template: temperature
-  - name: topP
-    use_template: top_p
-  - name: topK
-    label:
-      zh_Hans: 取样数量
-      en_US: Top K
-    type: int
-    help:
-      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
-      en_US: Only sample from the top K options for each subsequent token.
-    required: false
-    default: 250
-    min: 0
-    max: 500
-  - name: max_tokens_to_sample
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    type: int
+    default: 4096
+    min: 1
+    max: 4096
+    help:
+      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
+      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
+  - name: temperature
+    use_template: temperature
+    required: false
+    type: float
+    default: 1
+    min: 0.0
+    max: 1.0
+    help:
+      zh_Hans: 生成内容的随机性。
+      en_US: The amount of randomness injected into the response.
+  - name: top_p
+    required: false
+    type: float
+    default: 0.999
+    min: 0.000
+    max: 1.000
+    help:
+      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
+      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
+  - name: top_k
+    required: false
+    type: int
+    default: 0
+    min: 0
+    # tip: the AWS docs are wrong here; the max value is 500
+    max: 500
+    help:
+      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
+      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
  input: '0.0008'
  output: '0.0024'
@@ -1,33 +1,50 @@
model: anthropic.claude-v1
label:
-  en_US: Claude V1
+  en_US: Claude 1
model_type: llm
model_properties:
  mode: chat
  context_size: 100000
parameter_rules:
-  - name: temperature
-    use_template: temperature
-  - name: top_p
-    use_template: top_p
-  - name: top_k
-    label:
-      zh_Hans: 取样数量
-      en_US: Top K
-    type: int
-    help:
-      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
-      en_US: Only sample from the top K options for each subsequent token.
-    required: false
-    default: 250
-    min: 0
-    max: 500
-  - name: max_tokens_to_sample
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    type: int
+    default: 4096
+    min: 1
+    max: 4096
+    help:
+      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
+      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
+  - name: temperature
+    use_template: temperature
+    required: false
+    type: float
+    default: 1
+    min: 0.0
+    max: 1.0
+    help:
+      zh_Hans: 生成内容的随机性。
+      en_US: The amount of randomness injected into the response.
+  - name: top_p
+    required: false
+    type: float
+    default: 0.999
+    min: 0.000
+    max: 1.000
+    help:
+      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
+      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
+  - name: top_k
+    required: false
+    type: int
+    default: 0
+    min: 0
+    # tip: the AWS docs are wrong here; the max value is 500
+    max: 500
+    help:
+      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
+      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
  input: '0.008'
  output: '0.024'
@@ -1,33 +1,50 @@
model: anthropic.claude-v2:1
label:
-  en_US: Claude V2.1
+  en_US: Claude 2.1
model_type: llm
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
-  - name: temperature
-    use_template: temperature
-  - name: top_p
-    use_template: top_p
-  - name: top_k
-    label:
-      zh_Hans: 取样数量
-      en_US: Top K
-    type: int
-    help:
-      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
-      en_US: Only sample from the top K options for each subsequent token.
-    required: false
-    default: 250
-    min: 0
-    max: 500
-  - name: max_tokens_to_sample
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    type: int
+    default: 4096
+    min: 1
+    max: 4096
+    help:
+      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
+      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
+  - name: temperature
+    use_template: temperature
+    required: false
+    type: float
+    default: 1
+    min: 0.0
+    max: 1.0
+    help:
+      zh_Hans: 生成内容的随机性。
+      en_US: The amount of randomness injected into the response.
+  - name: top_p
+    required: false
+    type: float
+    default: 0.999
+    min: 0.000
+    max: 1.000
+    help:
+      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
+      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
+  - name: top_k
+    required: false
+    type: int
+    default: 0
+    min: 0
+    # tip: the AWS docs are wrong here; the max value is 500
+    max: 500
+    help:
+      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
+      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
  input: '0.008'
  output: '0.024'
@@ -1,33 +1,50 @@
model: anthropic.claude-v2
label:
-  en_US: Claude V2
+  en_US: Claude 2
model_type: llm
model_properties:
  mode: chat
  context_size: 100000
parameter_rules:
-  - name: temperature
-    use_template: temperature
-  - name: top_p
-    use_template: top_p
-  - name: top_k
-    label:
-      zh_Hans: 取样数量
-      en_US: Top K
-    type: int
-    help:
-      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
-      en_US: Only sample from the top K options for each subsequent token.
-    required: false
-    default: 250
-    min: 0
-    max: 500
-  - name: max_tokens_to_sample
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    type: int
+    default: 4096
+    min: 1
+    max: 4096
+    help:
+      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
+      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
+  - name: temperature
+    use_template: temperature
+    required: false
+    type: float
+    default: 1
+    min: 0.0
+    max: 1.0
+    help:
+      zh_Hans: 生成内容的随机性。
+      en_US: The amount of randomness injected into the response.
+  - name: top_p
+    required: false
+    type: float
+    default: 0.999
+    min: 0.000
+    max: 1.000
+    help:
+      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
+      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
+  - name: top_k
+    required: false
+    type: int
+    default: 0
+    min: 0
+    # tip: the AWS docs are wrong here; the max value is 500
+    max: 500
+    help:
+      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
+      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
  input: '0.008'
  output: '0.024'
@@ -1,9 +1,22 @@
+import base64
import json
import logging
+import mimetypes
+import time
from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast

import boto3
+import requests
+from anthropic import AnthropicBedrock, Stream
+from anthropic.types import (
+    ContentBlockDeltaEvent,
+    Message,
+    MessageDeltaEvent,
+    MessageStartEvent,
+    MessageStopEvent,
+    MessageStreamEvent,
+)
from botocore.config import Config
from botocore.exceptions import (
    ClientError,
@@ -13,14 +26,18 @@ from botocore.exceptions import (
    UnknownServiceError,
)

-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
@@ -54,9 +71,293 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
-       # invoke model
+       # invoke anthropic models via the anthropic official SDK
+       if "anthropic" in model:
+           return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+       # invoke other models via the boto3 client
+       return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

    def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """
        Invoke Anthropic large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param stop: stop words
        :param stream: is stream response
        :return: full response or stream response chunk generator result
        """
        # use Anthropic official SDK references
        # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock
        # - https://github.com/anthropics/anthropic-sdk-python
        client = AnthropicBedrock(
            aws_access_key=credentials["aws_access_key_id"],
            aws_secret_key=credentials["aws_secret_access_key"],
            aws_region=credentials["aws_region"],
        )

        extra_model_kwargs = {}
        if stop:
            extra_model_kwargs['stop_sequences'] = stop

        # Note: sending the `metadata` field with the current version of the SDK makes the
        # Bedrock server reject the request, so wait for the service or SDK to be updated.
        # Response: Error code: 400
        # {'message': 'Malformed input request: #: subject must not be valid against schema
        # {"required":["messages"]}#: extraneous key [metadata] is not permitted, please reformat your input and try again.'}
        # TODO: enable this in the future once the interface is properly supported
        # if user:
        #     ref: https://github.com/anthropics/anthropic-sdk-python/blob/e84645b07ca5267066700a104b4d8d6a8da1383d/src/anthropic/resources/messages.py#L465
        #     extra_model_kwargs['metadata'] = message_create_params.Metadata(user_id=user)

        system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages)

        if system:
            extra_model_kwargs['system'] = system

        response = client.messages.create(
            model=model,
            messages=prompt_message_dicts,
            stream=stream,
            **model_parameters,
            **extra_model_kwargs
        )

        if stream:
            return self._handle_claude_stream_response(model, credentials, response, prompt_messages)

        return self._handle_claude_response(model, credentials, response, prompt_messages)

    def _handle_claude_response(self, model: str, credentials: dict, response: Message,
                                prompt_messages: list[PromptMessage]) -> LLMResult:
        """
        Handle llm chat response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: full response chunk generator result
        """

        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=response.content[0].text
        )

        if response.usage:
            # transform usage
            prompt_tokens = response.usage.input_tokens
            completion_tokens = response.usage.output_tokens
        else:
            # calculate num tokens
            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        # transform response
        response = LLMResult(
            model=response.model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage
        )

        return response

    def _handle_claude_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent],
                                       prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm chat stream response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: full response or stream response chunk generator result
        """

        try:
            full_assistant_content = ''
            return_model = None
            input_tokens = 0
            output_tokens = 0
            finish_reason = None
            index = 0

            for chunk in response:
                if isinstance(chunk, MessageStartEvent):
                    return_model = chunk.message.model
                    input_tokens = chunk.message.usage.input_tokens
                elif isinstance(chunk, MessageDeltaEvent):
                    output_tokens = chunk.usage.output_tokens
                    finish_reason = chunk.delta.stop_reason
                elif isinstance(chunk, MessageStopEvent):
                    usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
                    yield LLMResultChunk(
                        model=return_model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=index + 1,
                            message=AssistantPromptMessage(
                                content=''
                            ),
                            finish_reason=finish_reason,
                            usage=usage
                        )
                    )
                elif isinstance(chunk, ContentBlockDeltaEvent):
                    chunk_text = chunk.delta.text if chunk.delta.text else ''
                    full_assistant_content += chunk_text
                    assistant_prompt_message = AssistantPromptMessage(
                        content=chunk_text if chunk_text else '',
                    )
                    index = chunk.index
                    yield LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=index,
                            message=assistant_prompt_message,
                        )
                    )
        except Exception as ex:
            raise InvokeError(str(ex))

    def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param prompt_tokens: prompt tokens
        :param completion_tokens: completion tokens
        :return: usage
        """
        # get prompt price info
        prompt_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=prompt_tokens,
        )

        # get completion price info
        completion_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.OUTPUT,
            tokens=completion_tokens
        )

        # transform usage
        usage = LLMUsage(
            prompt_tokens=prompt_tokens,
            prompt_unit_price=prompt_price_info.unit_price,
            prompt_price_unit=prompt_price_info.unit,
            prompt_price=prompt_price_info.total_amount,
            completion_tokens=completion_tokens,
            completion_unit_price=completion_price_info.unit_price,
            completion_price_unit=completion_price_info.unit,
            completion_price=completion_price_info.total_amount,
            total_tokens=prompt_tokens + completion_tokens,
            total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
            currency=prompt_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

    def _convert_claude_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
        """
        Convert prompt messages to a dict list plus a merged system prompt
        """

        system = ""
        first_loop = True
        for message in prompt_messages:
            if isinstance(message, SystemPromptMessage):
                message.content = message.content.strip()
                if first_loop:
                    system = message.content
                    first_loop = False
                else:
                    system += "\n"
                    system += message.content

        prompt_message_dicts = []
        for message in prompt_messages:
            if not isinstance(message, SystemPromptMessage):
                prompt_message_dicts.append(self._convert_claude_prompt_message_to_dict(message))

        return system, prompt_message_dicts
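To make the conversion concrete, here is roughly what the helper above produces; the message contents are invented for illustration:

# Hypothetical input/output sketch for _convert_claude_prompt_messages.
messages = [
    SystemPromptMessage(content="You are terse."),
    SystemPromptMessage(content="Answer in English."),
    UserPromptMessage(content="ping"),
]
system, dicts = self._convert_claude_prompt_messages(messages)
# system == "You are terse.\nAnswer in English."
# dicts  == [{"role": "user", "content": "ping"}]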

    def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict:
        """
        Convert PromptMessage to dict
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": "user", "content": message.content}
            else:
                sub_messages = []
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(TextPromptMessageContent, message_content)
                        sub_message_dict = {
                            "type": "text",
                            "text": message_content.data
                        }
                        sub_messages.append(sub_message_dict)
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(ImagePromptMessageContent, message_content)
                        if not message_content.data.startswith("data:"):
                            # fetch image data from url
                            try:
                                image_content = requests.get(message_content.data).content
                                mime_type, _ = mimetypes.guess_type(message_content.data)
                                base64_data = base64.b64encode(image_content).decode('utf-8')
                            except Exception as ex:
                                raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
                        else:
                            data_split = message_content.data.split(";base64,")
                            mime_type = data_split[0].replace("data:", "")
                            base64_data = data_split[1]

                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
                            raise ValueError(f"Unsupported image type {mime_type}, "
                                             f"only support image/jpeg, image/png, image/gif, and image/webp")

                        sub_message_dict = {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": mime_type,
                                "data": base64_data
                            }
                        }
                        sub_messages.append(sub_message_dict)

                message_dict = {"role": "user", "content": sub_messages}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_dict

    def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
@@ -101,7 +402,19 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
        :param credentials: model credentials
        :return:
        """

+       if "anthropic.claude-3" in model:
+           try:
+               self._invoke_claude(model=model,
+                                   credentials=credentials,
+                                   prompt_messages=[{"role": "user", "content": "ping"}],
+                                   model_parameters={},
+                                   stop=None,
+                                   stream=False)
+           except Exception as ex:
+               raise CredentialsValidateFailedError(str(ex))

        try:
            ping_message = UserPromptMessage(content="ping")
            self._generate(model=model,
@@ -472,7 +472,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
        else:
            raise ValueError(f"Got unknown type {message}")

-       if message.name is not None:
+       if message.name:
            message_dict["user_name"] = message.name

        return message_dict
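The truthiness check above also filters out empty-string names, which the `is not None` test would have forwarded to the API as a blank user_name. A quick demonstration:

# Why `if message.name:` instead of `if message.name is not None`:
for name in (None, "", "alice"):
    if name is not None:
        print("old check forwards:", repr(name))  # forwards "" too
    if name:
        print("new check forwards:", repr(name))  # only "alice"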
@@ -0,0 +1,11 @@
<svg width="112" height="24" viewBox="0 0 112 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M57.4336 17.092C56.4746 16.5453 55.7216 15.7924 55.1749 14.8244C54.6283 13.8564 54.3594 12.763 54.3594 11.544C54.3594 10.3251 54.6283 9.2137 55.1749 8.24571C55.7216 7.27772 56.4746 6.52485 57.4336 5.98708C58.3926 5.4493 59.4861 5.18042 60.6961 5.18042C61.6999 5.18042 62.623 5.3776 63.4476 5.77197C64.2722 6.16633 64.9445 6.73995 65.4554 7.49284L64.568 8.13816C64.1199 7.51076 63.5642 7.04469 62.9009 6.731C62.2377 6.41729 61.5027 6.26492 60.705 6.26492C59.7281 6.26492 58.8498 6.48899 58.0789 6.92818C57.2992 7.36736 56.6986 7.98579 56.2505 8.79244C55.8113 9.59014 55.5872 10.5133 55.5872 11.553C55.5872 12.5926 55.8113 13.5159 56.2505 14.3136C56.6896 15.1112 57.2992 15.7297 58.0789 16.1778C58.8587 16.617 59.7281 16.8411 60.705 16.8411C61.5027 16.8411 62.2377 16.6888 62.9009 16.375C63.5642 16.0613 64.1199 15.5953 64.568 14.9678L65.4554 15.6132C64.9445 16.366 64.2722 16.9396 63.4476 17.334C62.623 17.7284 61.7089 17.9255 60.6961 17.9255C59.4771 17.9255 58.3926 17.6568 57.4336 17.11V17.092Z" fill="#F55036"/>
<path d="M67.2754 0H68.4763V17.8181H67.2754V0Z" fill="#F55036"/>
<path d="M73.6754 17.092C72.7254 16.5454 71.9725 15.7924 71.4347 14.8244C70.888 13.8564 70.6191 12.763 70.6191 11.544C70.6191 10.3251 70.888 9.23163 71.4347 8.26364C71.9814 7.29566 72.7254 6.54277 73.6754 5.99604C74.6255 5.4493 75.6921 5.18042 76.8841 5.18042C78.0762 5.18042 79.1338 5.4493 80.0928 5.99604C81.0429 6.54277 81.7957 7.29566 82.3335 8.26364C82.8803 9.23163 83.1492 10.3251 83.1492 11.544C83.1492 12.763 82.8803 13.8564 82.3335 14.8244C81.7868 15.7924 81.0429 16.5454 80.0928 17.092C79.1427 17.6387 78.0673 17.9076 76.8841 17.9076C75.7011 17.9076 74.6344 17.6387 73.6754 17.092ZM79.4655 16.1599C80.2273 15.7118 80.8277 15.0843 81.2669 14.2867C81.7062 13.489 81.9302 12.5747 81.9302 11.553C81.9302 10.5312 81.7062 9.61703 81.2669 8.81933C80.8277 8.02164 80.2273 7.39425 79.4655 6.9461C78.7036 6.49796 77.8431 6.27389 76.8841 6.27389C75.9251 6.27389 75.0646 6.49796 74.3028 6.9461C73.5409 7.39425 72.9405 8.02164 72.5013 8.81933C72.0621 9.61703 71.838 10.5312 71.838 11.553C71.838 12.5747 72.0621 13.489 72.5013 14.2867C72.9405 15.0843 73.5409 15.7118 74.3028 16.1599C75.0646 16.608 75.9251 16.8322 76.8841 16.8322C77.8431 16.8322 78.7036 16.608 79.4655 16.1599Z" fill="#F55036"/>
<path d="M96.2799 5.27905V17.8091H95.1237V15.1203C94.7114 15.9986 94.0929 16.6887 93.2774 17.1728C92.4618 17.6567 91.5027 17.9077 90.4003 17.9077C88.769 17.9077 87.4873 17.4506 86.5553 16.5364C85.6231 15.6222 85.166 14.3136 85.166 12.6017V5.27905H86.367V12.5031C86.367 13.9102 86.7255 14.9858 87.4515 15.7207C88.1775 16.4557 89.1903 16.8232 90.4989 16.8232C91.9061 16.8232 93.0264 16.384 93.851 15.5057C94.6756 14.6272 95.0878 13.4442 95.0878 11.9563V5.27905H96.2889H96.2799Z" fill="#F55036"/>
<path d="M110.952 0V17.8181H109.777V14.8604C109.284 15.8374 108.585 16.5902 107.689 17.119C106.793 17.6479 105.78 17.9077 104.642 17.9077C103.503 17.9077 102.419 17.6389 101.469 17.0922C100.528 16.5454 99.7838 15.7925 99.246 14.8336C98.7083 13.8745 98.4395 12.781 98.4395 11.5441C98.4395 10.3073 98.7083 9.2138 99.246 8.24582C99.7838 7.27783 100.519 6.52496 101.469 5.98718C102.41 5.44941 103.468 5.18053 104.642 5.18053C105.816 5.18053 106.766 5.44044 107.653 5.96925C108.541 6.49807 109.24 7.23301 109.75 8.17411V0H110.952ZM107.295 16.16C108.057 15.7119 108.657 15.0844 109.096 14.2868C109.535 13.4891 109.759 12.5749 109.759 11.5531C109.759 10.5313 109.535 9.61713 109.096 8.81944C108.657 8.02174 108.057 7.39434 107.295 6.9462C106.533 6.49807 105.672 6.27399 104.713 6.27399C103.754 6.27399 102.894 6.49807 102.132 6.9462C101.37 7.39434 100.77 8.02174 100.331 8.81944C99.8914 9.61713 99.6673 10.5313 99.6673 11.5531C99.6673 12.5749 99.8914 13.4891 100.331 14.2868C100.77 15.0844 101.37 15.7119 102.132 16.16C102.894 16.6081 103.754 16.8322 104.713 16.8322C105.672 16.8322 106.533 16.6081 107.295 16.16Z" fill="#F55036"/>
<path d="M30.6085 5.27024C27.077 5.27024 24.209 8.13835 24.209 11.6697C24.209 15.201 27.077 18.0692 30.6085 18.0692C34.1399 18.0692 37.0079 15.201 37.0079 11.6697C37.0079 8.13835 34.1399 5.27921 30.6085 5.27024ZM30.6085 15.6672C28.4036 15.6672 26.611 13.8746 26.611 11.6697C26.611 9.46486 28.4036 7.67228 30.6085 7.67228C32.8133 7.67228 34.6059 9.46486 34.6059 11.6697C34.6059 13.8746 32.8133 15.6672 30.6085 15.6672Z" fill="black"/>
<path d="M6.45358 5.23422C2.92222 5.19837 0.036187 8.0396 0.000335591 11.571C-0.0355158 15.1023 2.80571 17.9974 6.33706 18.0242C6.37292 18.0242 6.41773 18.0242 6.45358 18.0242H8.55986V15.6311H6.45358C4.24873 15.658 2.43823 13.8923 2.41134 11.6785C2.38445 9.47365 4.15014 7.66315 6.36395 7.63626C6.39084 7.63626 6.4267 7.63626 6.45358 7.63626C8.65844 7.63626 10.46 9.42884 10.46 11.6337V17.5222C10.46 19.7092 8.67637 21.4929 6.48943 21.5197C5.44078 21.5197 4.44591 21.0895 3.71095 20.3455L2.01698 22.0395C3.1911 23.2227 4.7865 23.8949 6.45358 23.9128H6.54321C10.0298 23.859 12.8351 21.0357 12.853 17.5491V11.4724C12.7635 8.00374 9.93116 5.23422 6.46254 5.23422H6.45358Z" fill="black"/>
<path d="M51.2406 11.5082C51.151 8.03961 48.3187 5.27009 44.8501 5.27009C41.3187 5.23423 38.4237 8.07545 38.3968 11.6068C38.361 15.1382 41.2022 18.0331 44.7335 18.0601C44.7694 18.0601 44.8143 18.0601 44.8501 18.0601H46.9563V15.667H44.8501C42.6452 15.6939 40.8347 13.9282 40.8078 11.7144C40.7809 9.5095 42.5467 7.69902 44.7604 7.67213C44.7874 7.67213 44.8232 7.67213 44.8501 7.67213C47.055 7.67213 48.8565 9.46469 48.8565 11.6696V23.626L51.2406 23.6528V11.5082Z" fill="black"/>
<path d="M14.6808 18.0602H17.0649V11.6607C17.0649 9.45589 18.8575 7.66332 21.0623 7.66332C21.7883 7.66332 22.4695 7.8605 23.0611 8.2011L24.2621 6.12172C23.3209 5.57498 22.2276 5.27024 21.0713 5.27024C17.5399 5.27024 14.6719 8.13835 14.6719 11.6697V18.0692L14.6808 18.0602Z" fill="black"/>
</svg>
@@ -0,0 +1,4 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="24" height="24" rx="12" fill="#F55036"/>
<path d="M12.146 6.00022C9.87734 5.97718 8.02325 7.80249 8.00022 10.0712C7.97718 12.3398 9.80249 14.1997 12.0712 14.217C12.0942 14.217 12.123 14.217 12.146 14.217H13.4992V12.6796H12.146C10.7295 12.6968 9.56641 11.5625 9.54913 10.1403C9.53186 8.72377 10.6662 7.56065 12.0884 7.54337C12.1057 7.54337 12.1287 7.54337 12.146 7.54337C13.5625 7.54337 14.7199 8.69498 14.7199 10.1115V13.8945C14.7199 15.2995 13.574 16.4453 12.169 16.4626C11.4953 16.4626 10.8562 16.1862 10.384 15.7083L9.29578 16.7965C10.0501 17.5566 11.075 17.9885 12.146 18H12.2036C14.4435 17.9654 16.2457 16.1516 16.2572 13.9117V10.0078C16.1997 7.77945 14.3801 6.00022 12.1518 6.00022H12.146Z" fill="white"/>
</svg>
api/core/model_runtime/model_providers/groq/groq.py (new file, 29 lines)
@@ -0,0 +1,29 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class GroqProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials;
        if validation fails, raise an exception.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            model_instance.validate_credentials(
                model='llama2-70b-4096',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
api/core/model_runtime/model_providers/groq/groq.yaml (new file, 32 lines)
@@ -0,0 +1,32 @@
provider: groq
label:
  zh_Hans: GroqCloud
  en_US: GroqCloud
description:
  en_US: GroqCloud provides access to the Groq Cloud API, which hosts models like LLama2 and Mixtral.
  zh_Hans: GroqCloud 提供对 Groq Cloud API 的访问,其中托管了 LLama2 和 Mixtral 等模型。
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.svg
background: "#F5F5F4"
help:
  title:
    en_US: Get your API Key from GroqCloud
    zh_Hans: 从 GroqCloud 获取 API Key
  url:
    en_US: https://console.groq.com/
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
@@ -0,0 +1,25 @@
model: llama2-70b-4096
label:
  zh_Hans: Llama-2-70B-4096
  en_US: Llama-2-70B-4096
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 4096
pricing:
  input: '0.7'
  output: '0.8'
  unit: '0.000001'
  currency: USD
api/core/model_runtime/model_providers/groq/llm/llm.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from collections.abc import Generator
from typing import Optional, Union

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        credentials['mode'] = 'chat'
        credentials['endpoint_url'] = 'https://api.groq.com/openai/v1'
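Because Groq exposes an OpenAI-compatible API, pinning the endpoint is essentially all this adapter configures. A minimal sketch of what that wiring amounts to outside the runtime (the API key and prompt are placeholders):

# Hedged sketch: a stock OpenAI client pointed at the endpoint_url above.
from openai import OpenAI

client = OpenAI(api_key='YOUR_GROQ_API_KEY',
                base_url='https://api.groq.com/openai/v1')
completion = client.chat.completions.create(
    model='llama2-70b-4096',
    messages=[{'role': 'user', 'content': 'ping'}],
)
print(completion.choices[0].message.content)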
@@ -0,0 +1,25 @@
model: mixtral-8x7b-32768
label:
  zh_Hans: Mixtral-8x7b-Instruct-v0.1
  en_US: Mixtral-8x7b-Instruct-v0.1
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 20480
pricing:
  input: '0.27'
  output: '0.27'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,4 @@
model: jina-colbert-v1-en
model_type: rerank
model_properties:
  context_size: 8192
@@ -1,20 +1,32 @@
from os.path import abspath, dirname, join
+from threading import Lock

from transformers import AutoTokenizer


class JinaTokenizer:
-   @staticmethod
-   def _get_num_tokens_by_jina_base(text: str) -> int:
+   _tokenizer = None
+   _lock = Lock()
+
+   @classmethod
+   def _get_tokenizer(cls):
+       if cls._tokenizer is None:
+           with cls._lock:
+               if cls._tokenizer is None:
+                   base_path = abspath(__file__)
+                   gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
+                   cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
+       return cls._tokenizer
+
+   @classmethod
+   def _get_num_tokens_by_jina_base(cls, text: str) -> int:
        """
        use jina tokenizer to get num tokens
        """
-       base_path = abspath(__file__)
-       gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
-       tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
+       tokenizer = cls._get_tokenizer()
        tokens = tokenizer.encode(text)
        return len(tokens)

-   @staticmethod
-   def get_num_tokens(text: str) -> int:
-       return JinaTokenizer._get_num_tokens_by_jina_base(text)
+   @classmethod
+   def get_num_tokens(cls, text: str) -> int:
+       return cls._get_num_tokens_by_jina_base(text)
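The point of this refactor: the tokenizer was previously re-loaded from disk on every call, and is now cached as a class attribute behind double-checked locking, so concurrent first callers do not each pay the load cost. The same pattern in isolation, with illustrative names:

# Generic sketch of the lazy, thread-safe singleton used above.
from threading import Lock

class LazyHolder:
    _instance = None
    _lock = Lock()

    @classmethod
    def get(cls):
        if cls._instance is None:          # fast path, no lock
            with cls._lock:
                if cls._instance is None:  # re-check under the lock
                    cls._instance = object()  # stand-in for the expensive load
        return cls._instance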
@@ -57,7 +57,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
        try:
            response = post(url, headers=headers, data=dumps(data))
        except Exception as e:
-           raise InvokeConnectionError(e)
+           raise InvokeConnectionError(str(e))

        if response.status_code != 200:
            try:
@@ -1,6 +1,5 @@
from collections.abc import Generator
from typing import cast
-from urllib.parse import urljoin

from httpx import Timeout
from openai import (
@@ -19,6 +18,7 @@ from openai import (
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.completion import Completion
+from yarl import URL

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
@@ -181,7 +181,7 @@ class LocalAILarguageModel(LargeLanguageModel):
                UserPromptMessage(content='ping')
            ], model_parameters={
                'max_tokens': 10,
-           }, stop=[])
+           }, stop=[], stream=False)
        except Exception as ex:
            raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')
@@ -227,6 +227,12 @@ class LocalAILarguageModel(LargeLanguageModel):
            )
        ]

+       model_properties = {
+           ModelPropertyKey.MODE: completion_model,
+       } if completion_model else {}
+
+       model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048'))
+
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
@@ -234,7 +240,7 @@ class LocalAILarguageModel(LargeLanguageModel):
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
-           model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
+           model_properties=model_properties,
            parameter_rules=rules
        )

@@ -319,7 +325,7 @@ class LocalAILarguageModel(LargeLanguageModel):
        client_kwargs = {
            "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
            "api_key": "1",
-           "base_url": urljoin(credentials['server_url'], 'v1'),
+           "base_url": str(URL(credentials['server_url']) / 'v1'),
        }

        return client_kwargs
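The switch from urljoin to yarl matters when the configured server URL carries a path: urljoin replaces the last path segment if the base has no trailing slash, while yarl's `/` operator always appends. A quick illustration:

# Why str(URL(base) / 'v1') replaces urljoin(base, 'v1') here.
from urllib.parse import urljoin
from yarl import URL

base = 'http://192.168.1.100:8080/localai'    # no trailing slash
print(urljoin(base, 'v1'))                    # http://192.168.1.100:8080/v1 (segment dropped)
print(str(URL(base) / 'v1'))                  # http://192.168.1.100:8080/localai/v1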
@@ -56,3 +56,12 @@ model_credential_schema:
      placeholder:
        zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
        en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
+   - variable: context_size
+     label:
+       zh_Hans: 上下文大小
+       en_US: Context size
+     placeholder:
+       zh_Hans: 输入上下文大小
+       en_US: Enter context size
+     required: false
+     type: text-input
@@ -1,11 +1,12 @@
import time
from json import JSONDecodeError, dumps
-from os.path import join
from typing import Optional

from requests import post
+from yarl import URL

-from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
@@ -57,9 +58,9 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
        }

        try:
-           response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10)
+           response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10)
        except Exception as e:
-           raise InvokeConnectionError(e)
+           raise InvokeConnectionError(str(e))

        if response.status_code != 200:
            try:
@@ -113,6 +114,27 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
            # use GPT2Tokenizer to get num tokens
            num_tokens += self._get_num_tokens_by_gpt2(text)
        return num_tokens

+   def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+       """
+       Get customizable model schema
+
+       :param model: model name
+       :param credentials: model credentials
+       :return: model schema
+       """
+       return AIModelEntity(
+           model=model,
+           label=I18nObject(zh_Hans=model, en_US=model),
+           model_type=ModelType.TEXT_EMBEDDING,
+           features=[],
+           fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+           model_properties={
+               ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')),
+               ModelPropertyKey.MAX_CHUNKS: 1,
+           },
+           parameter_rules=[]
+       )

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
@@ -65,7 +65,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
        try:
            response = post(url, headers=headers, data=dumps(data))
        except Exception as e:
-           raise InvokeConnectionError(e)
+           raise InvokeConnectionError(str(e))

        if response.status_code != 200:
            raise InvokeServerUnavailableError(response.text)

(binary image file changed: 2.3 KiB before, 7.2 KiB after)
@@ -24,7 +24,7 @@ parameter_rules:
    min: 1
    max: 8000
  - name: safe_prompt
-   defulat: false
+   default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.

@@ -24,7 +24,7 @@ parameter_rules:
    min: 1
    max: 8000
  - name: safe_prompt
-   defulat: false
+   default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.

@@ -24,7 +24,7 @@ parameter_rules:
    min: 1
    max: 8000
  - name: safe_prompt
-   defulat: false
+   default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.

@@ -24,7 +24,7 @@ parameter_rules:
    min: 1
    max: 2048
  - name: safe_prompt
-   defulat: false
+   default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.

@@ -24,7 +24,7 @@ parameter_rules:
    min: 1
    max: 8000
  - name: safe_prompt
-   defulat: false
+   default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.
@@ -1,10 +1,7 @@
-import importlib
import logging
import os
-from collections import OrderedDict
from typing import Optional

-import yaml
from pydantic import BaseModel

from core.model_runtime.entities.model_entities import ModelType
@@ -12,6 +9,8 @@ from core.model_runtime.entities.provider_entities import ProviderConfig, Provid
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
+from core.utils.module_import_helper import load_single_subclass_from_source
+from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map

logger = logging.getLogger(__name__)

@@ -200,7 +199,6 @@ class ModelProviderFactory:
        if self.model_provider_extensions:
            return self.model_provider_extensions

-       model_providers = {}

        # get the path of current classes
        current_path = os.path.abspath(__file__)
@@ -215,17 +213,10 @@ class ModelProviderFactory:
        ]

-       # get _position.yaml file path
-       position_file_path = os.path.join(model_providers_path, '_position.yaml')
-
-       # read _position.yaml file
-       position_map = {}
-       if os.path.exists(position_file_path):
-           with open(position_file_path, encoding='utf-8') as f:
-               positions = yaml.safe_load(f)
-               # convert list to dict with key as model provider name, value as index
-               position_map = {position: index for index, position in enumerate(positions)}
+       position_map = get_position_map(model_providers_path)

        # traverse all model_provider_dir_paths
+       model_providers: list[ModelProviderExtension] = []
        for model_provider_dir_path in model_provider_dir_paths:
            # get model_provider dir name
            model_provider_name = os.path.basename(model_provider_dir_path)
@@ -238,15 +229,10 @@ class ModelProviderFactory:

            # Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
            py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
-           spec = importlib.util.spec_from_file_location(f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}', py_path)
-           mod = importlib.util.module_from_spec(spec)
-           spec.loader.exec_module(mod)
-
-           model_provider_class = None
-           for name, obj in vars(mod).items():
-               if isinstance(obj, type) and issubclass(obj, ModelProvider) and obj != ModelProvider:
-                   model_provider_class = obj
-                   break
+           model_provider_class = load_single_subclass_from_source(
+               module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
+               script_path=py_path,
+               parent_type=ModelProvider)

            if not model_provider_class:
                logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
@@ -256,14 +242,13 @@ class ModelProviderFactory:
                logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
                continue

-           model_providers[model_provider_name] = ModelProviderExtension(
+           model_providers.append(ModelProviderExtension(
                name=model_provider_name,
                provider_instance=model_provider_class(),
                position=position_map.get(model_provider_name)
-           )
+           ))

-       sorted_items = sorted(model_providers.items(), key=lambda x: (x[1].position is None, x[1].position))
-       sorted_extensions = OrderedDict(sorted_items)
+       sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)

        self.model_provider_extensions = sorted_extensions
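The inlined _position.yaml parsing and the reflection loop move into shared helpers. For the refactor to be behavior-preserving, the helpers would need roughly the following semantics; this is an assumption about `core.utils.position_helper`, not its actual source:

# Assumed behavior of the new helpers; illustrative, not the real implementation.
import os
from collections import OrderedDict

import yaml

def get_position_map(folder_path: str) -> dict:
    # map each name listed in _position.yaml to its index
    position_file_path = os.path.join(folder_path, '_position.yaml')
    if not os.path.exists(position_file_path):
        return {}
    with open(position_file_path, encoding='utf-8') as f:
        positions = yaml.safe_load(f) or []
    return {name: index for index, name in enumerate(positions)}

def sort_to_dict_by_position_map(position_map: dict, data: list, name_func) -> OrderedDict:
    # items named in the map come first, in map order; the rest keep scan order
    sorted_items = sorted(data, key=lambda x: (name_func(x) not in position_map,
                                               position_map.get(name_func(x), 0)))
    return OrderedDict((name_func(item), item) for item in sorted_items)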
(new binary image file added: 110 KiB)
@@ -0,0 +1,3 @@
<svg width="567" height="376" viewBox="0 0 567 376" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M58.0366 161.868C58.0366 161.868 109.261 86.2912 211.538 78.4724V51.053C98.2528 60.1511 0.152344 156.098 0.152344 156.098C0.152344 156.098 55.7148 316.717 211.538 331.426V302.282C97.1876 287.896 58.0366 161.868 58.0366 161.868ZM211.538 244.32V271.013C125.114 255.603 101.125 165.768 101.125 165.768C101.125 165.768 142.621 119.799 211.538 112.345V141.633C211.486 141.633 211.449 141.617 211.406 141.617C175.235 137.276 146.978 171.067 146.978 171.067C146.978 171.067 162.816 227.949 211.538 244.32ZM211.538 0.47998V51.053C214.864 50.7981 218.189 50.5818 221.533 50.468C350.326 46.1273 434.243 156.098 434.243 156.098C434.243 156.098 337.861 273.296 237.448 273.296C228.245 273.296 219.63 272.443 211.538 271.009V302.282C218.695 303.201 225.903 303.667 233.119 303.675C326.56 303.675 394.134 255.954 459.566 199.474C470.415 208.162 514.828 229.299 523.958 238.55C461.745 290.639 316.752 332.626 234.551 332.626C226.627 332.626 219.018 332.148 211.538 331.426V375.369H566.701V0.47998H211.538ZM211.538 112.345V78.4724C214.829 78.2425 218.146 78.0672 221.533 77.9602C314.148 75.0512 374.909 157.548 374.909 157.548C374.909 157.548 309.281 248.693 238.914 248.693C228.787 248.693 219.707 247.065 211.536 244.318V141.631C247.591 145.987 254.848 161.914 276.524 198.049L324.737 157.398C324.737 157.398 289.544 111.243 230.219 111.243C223.768 111.241 217.597 111.696 211.538 112.345Z" fill="#77B900"/>
</svg>
@@ -0,0 +1,4 @@
- google/gemma-7b
- meta/llama2-70b
- mistralai/mixtral-8x7b-instruct-v0.1
- fuyu-8b
@@ -0,0 +1,27 @@
model: fuyu-8b
label:
  zh_Hans: fuyu-8b
  en_US: fuyu-8b
model_type: llm
features:
  - agent-thought
  - vision
model_properties:
  mode: chat
  context_size: 16000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.2
    min: 0.1
    max: 1
  - name: top_p
    use_template: top_p
    default: 0.7
    min: 0.1
    max: 1
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 1024
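Each use_template above binds a rule to one of the runtime's built-in parameter templates, while default/min/max bound what callers may pass in model_parameters. A tiny sketch of how such a rule constrains a caller-supplied value (a hypothetical helper for illustration, not dify code):

# Hypothetical helper mirroring a rule with default 0.2, min 0.1, max 1
# (the fuyu-8b temperature rule above); not part of the dify codebase.
def apply_rule(value, default=0.2, minimum=0.1, maximum=1.0):
    if value is None:
        return default  # absent parameters fall back to the rule's default
    if not minimum <= value <= maximum:
        raise ValueError(f'value {value} outside [{minimum}, {maximum}]')
    return value

assert apply_rule(None) == 0.2   # temperature falls back to its default
assert apply_rule(0.9) == 0.9    # in-range values pass through unchanged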
@@ -0,0 +1,30 @@
model: google/gemma-7b
label:
  zh_Hans: google/gemma-7b
  en_US: google/gemma-7b
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 1024
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
@@ -0,0 +1,30 @@
model: meta/llama2-70b
label:
  zh_Hans: meta/llama2-70b
  en_US: meta/llama2-70b
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 1024
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
api/core/model_runtime/model_providers/nvidia/llm/llm.py (new file, 247 lines)
@@ -0,0 +1,247 @@
import json
from collections.abc import Generator
from typing import Optional, Union

import requests
from yarl import URL

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageContentType,
    PromptMessageFunction,
    PromptMessageTool,
    UserPromptMessage,
)
from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
from core.model_runtime.utils import helper


class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
    # models with a non-empty suffix are served from dedicated ai.api.nvidia.com
    # paths; the rest share the OpenAI-compatible integrate.api.nvidia.com endpoint
    MODEL_SUFFIX_MAP = {
        'fuyu-8b': 'vlm/adept/fuyu-8b',
        'mistralai/mixtral-8x7b-instruct-v0.1': '',
        'google/gemma-7b': '',
        'meta/llama2-70b': ''
    }

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:

        self._add_custom_parameters(credentials, model)
        prompt_messages = self._transform_prompt_messages(prompt_messages)
        # caller-supplied stop words and user ids are discarded before
        # delegating to the OpenAI-compatible base implementation
        stop = []
        user = None

        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def _transform_prompt_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Handle image transform: flatten multi-part user messages into a single
        text string, inlining image parts as <img> tags.
        """
        for i, p in enumerate(prompt_messages):
            if isinstance(p, UserPromptMessage) and isinstance(p.content, list):
                content = p.content
                content_text = ''
                for prompt_content in content:
                    if prompt_content.type == PromptMessageContentType.TEXT:
                        content_text += prompt_content.data
                    else:
                        content_text += f' <img src="{prompt_content.data}" />'

                prompt_message = UserPromptMessage(
                    content=content_text
                )
                prompt_messages[i] = prompt_message
        return prompt_messages

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials, model)
        self._validate_credentials(model, credentials)

    def _add_custom_parameters(self, credentials: dict, model: str) -> None:
        credentials['mode'] = 'chat'

        if self.MODEL_SUFFIX_MAP[model]:
            # dedicated per-model endpoint replaces the generic one
            credentials['server_url'] = f'https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}'
            credentials.pop('endpoint_url')
        else:
            credentials['endpoint_url'] = 'https://integrate.api.nvidia.com/v1'

        credentials['stream_mode_delimiter'] = '\n'

    def _validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials using requests to ensure compatibility with all
        providers following OpenAI's API standard.

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            headers = {
                'Content-Type': 'application/json'
            }

            api_key = credentials.get('api_key')
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None
            if endpoint_url and not endpoint_url.endswith('/'):
                endpoint_url += '/'
            server_url = credentials['server_url'] if 'server_url' in credentials else None

            # prepare the payload for a simple ping to the model
            data = {
                'model': model,
                'max_tokens': 5
            }

            completion_type = LLMMode.value_of(credentials['mode'])

            if completion_type is LLMMode.CHAT:
                data['messages'] = [
                    {
                        "role": "user",
                        "content": "ping"
                    },
                ]
                if 'endpoint_url' in credentials:
                    endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions')
                elif 'server_url' in credentials:
                    endpoint_url = server_url
            elif completion_type is LLMMode.COMPLETION:
                data['prompt'] = 'ping'
                if 'endpoint_url' in credentials:
                    endpoint_url = str(URL(endpoint_url) / 'completions')
                elif 'server_url' in credentials:
                    endpoint_url = server_url
            else:
                raise ValueError("Unsupported completion type for model configuration.")

            # send a post request to validate the credentials
            response = requests.post(
                endpoint_url,
                headers=headers,
                json=data,
                timeout=(10, 60)
            )

            if response.status_code != 200:
                raise CredentialsValidateFailedError(
                    f'Credentials validation failed with status code {response.status_code}')

            # ensure the response body is valid JSON
            try:
                response.json()
            except json.JSONDecodeError:
                raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error')
        except CredentialsValidateFailedError:
            raise
        except Exception as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                  stream: bool = True,
                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """
        Invoke llm completion model

        :param model: model name
        :param credentials: credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        headers = {
            'Content-Type': 'application/json',
            'Accept-Charset': 'utf-8',
        }

        api_key = credentials.get('api_key')
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'

        if stream:
            headers['Accept'] = 'text/event-stream'

        endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None
        if endpoint_url and not endpoint_url.endswith('/'):
            endpoint_url += '/'
        server_url = credentials['server_url'] if 'server_url' in credentials else None

        data = {
            "model": model,
            "stream": stream,
            **model_parameters
        }

        completion_type = LLMMode.value_of(credentials['mode'])

        if completion_type is LLMMode.CHAT:
            if 'endpoint_url' in credentials:
                endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions')
            elif 'server_url' in credentials:
                endpoint_url = server_url
            data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
        elif completion_type is LLMMode.COMPLETION:
            # use the actual prompt content (the original hardcoded 'ping'
            # belonged to the validation payload)
            data['prompt'] = prompt_messages[0].content
            if 'endpoint_url' in credentials:
                endpoint_url = str(URL(endpoint_url) / 'completions')
            elif 'server_url' in credentials:
                endpoint_url = server_url
        else:
            raise ValueError("Unsupported completion type for model configuration.")

        # annotate tools with names, descriptions, etc.
        function_calling_type = credentials.get('function_calling_type', 'no_call')
        formatted_tools = []
        if tools:
            if function_calling_type == 'function_call':
                data['functions'] = [{
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters
                } for tool in tools]
            elif function_calling_type == 'tool_call':
                data["tool_choice"] = "auto"

                for tool in tools:
                    formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))

                data["tools"] = formatted_tools

        if stop:
            data["stop"] = stop

        if user:
            data["user"] = user

        response = requests.post(
            endpoint_url,
            headers=headers,
            json=data,
            timeout=(10, 60),
            stream=stream
        )

        if response.encoding is None or response.encoding == 'ISO-8859-1':
            response.encoding = 'utf-8'

        if not response.ok:
            raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")

        if stream:
            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

        return self._handle_generate_response(model, credentials, response, prompt_messages)
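Since every line of llm.py above is new, a quick orientation helps: _add_custom_parameters routes each model either to a dedicated ai.api.nvidia.com path or to the shared OpenAI-compatible integrate.api.nvidia.com endpoint, keyed on MODEL_SUFFIX_MAP. A standalone sketch of that routing (the function name is illustrative; this is not a call into the class):

# Mirrors the branch in _add_custom_parameters above, as a pure function.
MODEL_SUFFIX_MAP = {
    'fuyu-8b': 'vlm/adept/fuyu-8b',
    'mistralai/mixtral-8x7b-instruct-v0.1': '',
    'google/gemma-7b': '',
    'meta/llama2-70b': '',
}

def resolve_url(model: str) -> str:
    suffix = MODEL_SUFFIX_MAP[model]
    if suffix:
        # dedicated per-model endpoint (e.g. the fuyu-8b vision endpoint)
        return f'https://ai.api.nvidia.com/v1/{suffix}'
    # shared OpenAI-compatible endpoint
    return 'https://integrate.api.nvidia.com/v1'

assert resolve_url('fuyu-8b') == 'https://ai.api.nvidia.com/v1/vlm/adept/fuyu-8b'
assert resolve_url('meta/llama2-70b') == 'https://integrate.api.nvidia.com/v1'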
@@ -0,0 +1,30 @@
model: mistralai/mixtral-8x7b-instruct-v0.1
label:
  zh_Hans: mistralai/mixtral-8x7b-instruct-v0.1
  en_US: mistralai/mixtral-8x7b-instruct-v0.1
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 1024
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
api/core/model_runtime/model_providers/nvidia/nvidia.py (new file, 30 lines)
@@ -0,0 +1,30 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class NVIDIAProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials;
        if validation fails, raise an exception.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # validate against one of the provider's predefined LLMs
            model_instance.validate_credentials(
                model='mistralai/mixtral-8x7b-instruct-v0.1',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
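Provider-level validation above simply delegates to one predefined model's credential check. A hedged usage sketch of that chain (the API key is a placeholder, and bare instantiation is assumed to work here because the factory itself does model_provider_class() above):

# Hypothetical usage sketch; 'nvapi-...' is a placeholder key.
provider = NVIDIAProvider()
provider.validate_provider_credentials({'api_key': 'nvapi-...'})
# internally this resolves the provider's LLM model class and runs
# validate_credentials(model='mistralai/mixtral-8x7b-instruct-v0.1', ...)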
api/core/model_runtime/model_providers/nvidia/nvidia.yaml (new file, 30 lines)
@@ -0,0 +1,30 @@
provider: nvidia
label:
  en_US: NVIDIA
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.png
background: "#FFFFFF"
help:
  title:
    en_US: Get your API Key from NVIDIA
    zh_Hans: 从 NVIDIA 获取 API Key
  url:
    en_US: https://build.nvidia.com/explore/discover
supported_model_types:
  - llm
  - text-embedding
  - rerank
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
@@ -0,0 +1,4 @@
model: nv-rerank-qa-mistral-4b:1
model_type: rerank
model_properties:
  context_size: 8192
api/core/model_runtime/model_providers/nvidia/rerank/rerank.py (new file, 112 lines)
@@ -0,0 +1,112 @@
from math import exp
from typing import Optional

import requests

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class NvidiaRerankModel(RerankModel):
    """
    Model class for NVIDIA rerank model.
    """

    def _sigmoid(self, logit: float) -> float:
        # map the raw ranking logit onto a (0, 1) relevance score
        return 1 / (1 + exp(-logit))

    def _invoke(self, model: str, credentials: dict,
                query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
                user: Optional[str] = None) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n documents to return
        :param user: unique user id
        :return: rerank result
        """
        if len(docs) == 0:
            return RerankResult(model=model, docs=[])

        try:
            invoke_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking"

            headers = {
                "Authorization": f"Bearer {credentials.get('api_key')}",
                "Accept": "application/json",
            }
            payload = {
                "model": model,
                "query": {"text": query},
                "passages": [{"text": doc} for doc in docs],
            }

            session = requests.Session()
            response = session.post(invoke_url, headers=headers, json=payload)
            response.raise_for_status()
            results = response.json()

            # note: score_threshold and top_n are accepted but not applied here;
            # every passage is returned with its sigmoid-squashed logit
            rerank_documents = []
            for result in results['rankings']:
                index = result['index']
                logit = result['logit']
                rerank_document = RerankDocument(
                    index=index,
                    text=docs[index],
                    score=self._sigmoid(logit),
                )

                rerank_documents.append(rerank_document)

            return RerankResult(model=model, docs=rerank_documents)
        except requests.HTTPError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the GPU memory bandwidth of H100 SXM?",
                docs=[
                    "Example doc 1",
                    "Example doc 2",
                    "Example doc 3",
                ],
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        """
        return {
            InvokeConnectionError: [requests.ConnectionError],
            InvokeServerUnavailableError: [requests.HTTPError],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [requests.HTTPError],
            InvokeBadRequestError: [requests.RequestException]
        }
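For orientation, a hedged usage sketch of the rerank model above. The scores are the sigmoid-squashed logits returned by the NVIDIA reranking endpoint; the key and documents are placeholders, and bare instantiation is assumed to work as it does elsewhere in the runtime:

# Hypothetical usage sketch; 'nvapi-...' is a placeholder key.
reranker = NvidiaRerankModel()
result = reranker._invoke(
    model='nv-rerank-qa-mistral-4b:1',
    credentials={'api_key': 'nvapi-...'},
    query='What is the GPU memory bandwidth of H100 SXM?',
    docs=[
        'The H100 SXM has 3.35 TB/s of memory bandwidth.',
        'Dify is an LLM application development platform.',
    ],
)
for doc in result.docs:
    # each document keeps its original index alongside its (0, 1) score
    print(doc.index, round(doc.score, 3), doc.text)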
@@ -0,0 +1,5 @@
model: NV-Embed-QA
model_type: text-embedding
model_properties:
  context_size: 512
  max_chunks: 1
Some files were not shown because too many files have changed in this diff.