Mirror of https://github.com/langgenius/dify.git
Synced 2026-01-19 21:44:07 +00:00

Compare commits (199 commits)
ce5b19d011
f82a64d149
f49b1afd6c
796c5626a7
e54c9cd401
f8951d7f57
6454e1d644
e184c8cb42
fdd211e399
7001e21e7d
82d0732c12
53cd125780
3c91f9b5ab
f073dca22a
8b1e35d7dc
b75d8ca621
9beefd7d5a
88145efa97
bdc13f9238
ce58f0607b
bbc0d330a9
60e7e17c86
237bb8514e
bd26c933d2
b6b58da2d2
40c646cf7a
3231a8c51c
4170d6a491
0b50c525cf
8ba38e8e74
b163545771
c0b82f8e58
b75ff5fa03
9440d7fe88
24809fce07
9819ad347f
8fe83750b7
1809f05904
0ac250a035
405a00bb2c
3a3ca8e6a9
27e678480e
7052565380
31070ffbca
7f3dec7bee
b1e0db4944
c439952a41
2f28afebb6
fa7ba30ba3
1cf5f510ed
526c874caa
f88f744097
95733796f0
552f319b9d
38e5952417
7f891939f1
69a5ce1e31
534802b761
5c258e212c
6a6133c102
3c1825187a
8523b34be7
65cfd4360a
bbf5f42c87
3631e53ff0
f322d9bddb
05ce7b9d5e
72ddedfc5c
36686d7425
34387ec0f1
83a6b0c626
76da66fb7e
607f9eda35
f25cec265d
8e66b96221
b5c1bb346c
e94b323e6c
bc65ee10c0
2001483659
444aba55dd
3f640b1037
b07084711c
fa8ab2134f
1a677da792
b6d61a818e
8495ffaa45
dbd1d79770
1910178199
839a6a2c8a
a769edbc89
57ffecd0e5
801d135390
0428f44113
7beff3fd5a
88a095e40e
dd961985f0
d44b05a9e5
5bd3b02be6
3cf5c1853d
a4d86496e1
90bdc85f8c
0828873b52
816b707a16
c9257ab4bf
69ce3b3d33
c4caa7c401
dc93a292c3
174ee1b646
9b1c4f47fb
582ba45c00
f1cbd55007
3a34370422
29ab244de6
920b2c2b40
ac96d192a6
07fbeb6cf0
fc64cdee64
0c0e96c55f
5b953c1ef2
562ca45e07
6bbd53512e
e352a8ed1b
e55225e2bc
3e63abd335
0620fa3094
d93288f711
ca69af7b97
952e13fef8
4be3087642
49da8a23a8
3ad943a9eb
3082093293
b03bbab5ad
9574730050
91ea6fe4ee
769be13189
e42175241e
12257b438b
9ecc736c30
6c4e6bf1d6
97fe817186
52b12ed7eb
d8ab4474b4
1ecbd95adf
cad6e6624f
3505cbe05c
e15359e589
edb86f5f5a
adf2651d1f
5031d64e28
ae3ad59b16
e6cd7b0467
97e9f52331
25957d917a
20b932da97
207080babc
48bacd01cc
297d0f1f30
eedbe1b770
5ff6b1da07
8b49e0ee2a
e031ec9359
1bd1cd6938
81c5a21b8d
61e4bbabaf
4cf475680d
ca4aa340f6
767d8a4b05
0b8dcaba8f
af6a318aae
c6e2900be7
963d9b6032
b2ee738bb1
c8ca3ff404
5d8fa2c7af
58df5e5376
348ad1a624
73e17d5aa8
300d9892a5
e47b5b43b8
21c9d9e200
4f6916c4d8
8633957726
0850c953b3
23e95fd7ab
e1045f01c6
e6d22fc3a0
9232244920
476eb90a90
063191889d
589099a005
a0ec7de058
14a19a3da9
1b04382a9b
71e5828d41
65a02f7d32
acf9174bef
243ca5b1e2
f6059c377c
.github/ISSUE_TEMPLATE/bug_report.yml (4 changes)

```diff
@@ -10,7 +10,9 @@ body:
       options:
         - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
           required: true
-        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+        - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
           required: true
+        - label: "Pleas do not modify this template :) and fill in all the required fields."
+          required: true

   - type: input
```
.github/ISSUE_TEMPLATE/document_issue.yml (4 changes)

```diff
@@ -10,7 +10,9 @@ body:
       options:
         - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
           required: true
-        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+        - label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
           required: true
+        - label: "Pleas do not modify this template :) and fill in all the required fields."
+          required: true
   - type: textarea
     attributes:
```
.github/ISSUE_TEMPLATE/feature_request.yml (4 changes)

```diff
@@ -10,7 +10,9 @@ body:
       options:
         - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
           required: true
-        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+        - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
           required: true
+        - label: "Pleas do not modify this template :) and fill in all the required fields."
+          required: true
   - type: textarea
     attributes:
```
.github/ISSUE_TEMPLATE/help_wanted.yml (4 changes)

```diff
@@ -10,7 +10,9 @@ body:
       options:
         - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
           required: true
-        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+        - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
           required: true
+        - label: "Pleas do not modify this template :) and fill in all the required fields."
+          required: true
   - type: textarea
     attributes:
```
.github/ISSUE_TEMPLATE/translation_issue.yml (4 changes)

```diff
@@ -10,7 +10,9 @@ body:
       options:
         - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
           required: true
-        - label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+        - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
           required: true
+        - label: "Pleas do not modify this template :) and fill in all the required fields."
+          required: true
   - type: input
     attributes:
```
.github/pull_request_template.md (new file, 30 lines)

```diff
@@ -0,0 +1,30 @@
+# Description
+
+Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
+
+Fixes # (issue)
+
+## Type of Change
+
+Please delete options that are not relevant.
+
+- [ ] Bug fix (non-breaking change which fixes an issue)
+- [ ] New feature (non-breaking change which adds functionality)
+- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
+- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
+
+# How Has This Been Tested?
+
+Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration
+
+- [ ] TODO
+
+# Suggested Checklist:
+
+- [ ] I have performed a self-review of my own code
+- [ ] I have commented my code, particularly in hard-to-understand areas
+- [ ] My changes generate no new warnings
+- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
+- [ ] `optional` I have made corresponding changes to the documentation
+- [ ] `optional` I have added tests that prove my fix is effective or that my feature works
+- [ ] `optional` New and existing unit tests pass locally with my changes
```
.github/workflows/style.yml (5 changes)

```diff
@@ -41,6 +41,8 @@ jobs:
     steps:
       - name: Checkout code
        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0

      - name: Setup NodeJS
        uses: actions/setup-node@v4
@@ -60,11 +62,10 @@
          yarn run lint

      - name: Super-linter
-        uses: super-linter/super-linter/slim@v5
+        uses: super-linter/super-linter/slim@v6
        env:
-          BASH_SEVERITY: warning
          DEFAULT_BRANCH: main
          ERROR_ON_MISSING_EXEC_BIT: true
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          IGNORE_GENERATED_FILES: true
          IGNORE_GITIGNORED_FILES: true
```
.github/workflows/tool-test-sdks.yaml (new file, 34 lines)

```diff
@@ -0,0 +1,34 @@
+name: Run Unit Test For SDKs
+
+on:
+  pull_request:
+    branches:
+      - main
+jobs:
+  build:
+    name: unit test for Node.js SDK
+    runs-on: ubuntu-latest
+
+    strategy:
+      matrix:
+        node-version: [16, 18, 20]
+
+    defaults:
+      run:
+        working-directory: sdks/nodejs-client
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Use Node.js ${{ matrix.node-version }}
+        uses: actions/setup-node@v4
+        with:
+          node-version: ${{ matrix.node-version }}
+          cache: ''
+          cache-dependency-path: 'yarn.lock'
+
+      - name: Install Dependencies
+        run: yarn install
+
+      - name: Test
+        run: yarn test
```
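This workflow runs the Node.js client SDK's unit tests on Node.js 16, 18, and 20 for every pull request against `main`. Judging from its `working-directory` and steps, the same suite can presumably be run locally with `cd sdks/nodejs-client && yarn install && yarn test`.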
.gitignore (3 changes)

```diff
@@ -145,6 +145,9 @@ docker/volumes/db/data/*
 docker/volumes/redis/data/*
 docker/volumes/weaviate/*
 docker/volumes/qdrant/*
+docker/volumes/etcd/*
+docker/volumes/minio/*
+docker/volumes/milvus/*

 sdks/python-client/build
 sdks/python-client/dist
```
LICENSE (22 changes)

```diff
@@ -1,24 +1,26 @@
-# Dify Open Source License
+# Open Source License

-The Dify project is licensed under the Apache License 2.0, with the following additional conditions:
+Dify is licensed under the Apache License 2.0, with the following additional conditions:

-1. Dify is permitted to be used for commercialization, such as using Dify as a "backend-as-a-service" for your other applications, or delivering it to enterprises as an application development platform. However, when the following conditions are met, you must contact the producer to obtain a commercial license:
+1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:

-a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify.AI source code to operate a multi-tenant SaaS service that is similar to the Dify.AI service edition.
-b. LOGO and copyright information: In the process of using Dify, you may not remove or modify the LOGO or copyright information in the Dify console.
+a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
+    - Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
+
+b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.

 Please contact business@dify.ai by email to inquire about licensing matters.

-2. As a contributor, you should agree that your contributed code:
+2. As a contributor, you should agree that:

-a. The producer can adjust the open-source agreement to be more strict or relaxed.
-b. Can be used for commercial purposes, such as Dify's cloud business.
+a. The producer can adjust the open-source agreement to be more strict or relaxed as deemed necessary.
+b. Your contributed code may be used for commercial purposes, including but not limited to its cloud business operations.

-Apart from this, all other rights and restrictions follow the Apache License 2.0. If you need more detailed information, you can refer to the full version of Apache License 2.0.
+Apart from the specific conditions mentioned above, all other rights and restrictions follow the Apache License 2.0. Detailed information about the Apache License 2.0 can be found at http://www.apache.org/licenses/LICENSE-2.0.

 The interactive design of this product is protected by appearance patent.

-© 2023 LangGenius, Inc.
+© 2024 LangGenius, Inc.


 ----------
```
```diff
@@ -81,11 +81,17 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
 # Model Configuration
 MULTIMODAL_SEND_IMAGE_FORMAT=base64

-# Mail configuration, support: resend
+# Mail configuration, support: resend, smtp
 MAIL_TYPE=
 MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
 RESEND_API_KEY=
 RESEND_API_URL=https://api.resend.com
+# smtp configuration
+SMTP_SERVER=smtp.gmail.com
+SMTP_PORT=587
+SMTP_USERNAME=123
+SMTP_PASSWORD=abc
+SMTP_USE_TLS=false

 # Sentry configuration
 SENTRY_DSN=
@@ -124,3 +130,5 @@ UNSTRUCTURED_API_URL=
 SSRF_PROXY_HTTP_URL=
 SSRF_PROXY_HTTPS_URL=
+
+BATCH_UPLOAD_LIMIT=10
```
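For context, a minimal sketch of how SMTP settings shaped like the ones above are typically consumed, using only the standard library. The environment variable names mirror the example entries, but this is not Dify's actual mail client, and the recipient address is a placeholder:

```python
import os
import smtplib
from email.message import EmailMessage

# Build a simple text message; the sender default mirrors MAIL_DEFAULT_SEND_FROM above.
msg = EmailMessage()
msg["From"] = os.environ.get("MAIL_DEFAULT_SEND_FROM", "no-reply <no-reply@dify.ai>")
msg["To"] = "user@example.com"  # placeholder recipient
msg["Subject"] = "SMTP configuration test"
msg.set_content("Hello from the smtp mail type.")

# Connect with SMTP_SERVER/SMTP_PORT, optionally upgrade the connection,
# then authenticate and send.
with smtplib.SMTP(os.environ["SMTP_SERVER"], int(os.environ["SMTP_PORT"])) as smtp:
    if os.environ.get("SMTP_USE_TLS", "false").lower() == "true":
        smtp.starttls()  # STARTTLS upgrade, matching SMTP_USE_TLS
    smtp.login(os.environ["SMTP_USERNAME"], os.environ["SMTP_PASSWORD"])
    smtp.send_message(msg)
```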
````diff
@@ -5,7 +5,7 @@
 1. Start the docker-compose stack

    The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.

    ```bash
    cd ../docker
    docker-compose -f docker-compose.middleware.yaml -p dify up -d
````
````diff
@@ -15,7 +15,7 @@
 3. Generate a `SECRET_KEY` in the `.env` file.

    ```bash
-   openssl rand -base64 42
+   sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
    ```
 3.5 If you use annaconda, create a new environment and activate it
    ```bash
````
````diff
@@ -46,7 +46,7 @@
    ```
    pip install -r requirements.txt --upgrade --force-reinstall
    ```

 6. Start backend:

    ```bash
    flask run --host 0.0.0.0 --port=5001 --debug
````
```diff
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import os

 from werkzeug.exceptions import Unauthorized
@@ -27,6 +26,7 @@ from config import CloudEditionConfig, Config
 from extensions import (
     ext_celery,
     ext_code_based_extension,
+    ext_compress,
     ext_database,
     ext_hosting_provider,
     ext_login,
@@ -39,10 +39,11 @@ from extensions import (
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
 from libs.passport import PassportService

 # DO NOT REMOVE BELOW
 from services.account_service import AccountService

 # DO NOT REMOVE BELOW
 from events import event_handlers
 from models import account, dataset, model, source, task, tool, tools, web
 # DO NOT REMOVE ABOVE
@@ -96,6 +97,7 @@ def create_app(test_config=None) -> Flask:
 def initialize_extensions(app):
     # Since the application instance is now created, pass it to each Flask
     # extension instance to bind it to the Flask application instance (app)
+    ext_compress.init_app(app)
     ext_code_based_extension.init()
     ext_database.init_app(app)
     ext_migrate.init(app, db)
```
api/commands.py (296 changes)

```diff
@@ -6,16 +6,16 @@ import click
 from flask import current_app
 from werkzeug.exceptions import NotFound

 from core.embedding.cached_embedding import CacheEmbedding
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from libs.helper import email as email_validate
 from libs.password import hash_password, password_pattern, valid_password
 from libs.rsa import generate_key_pair
 from models.account import Tenant
-from models.dataset import Dataset
-from models.model import Account
+from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
+from models.dataset import Document as DatasetDocument
+from models.model import Account, App, AppAnnotationSetting, MessageAnnotation
 from models.provider import Provider, ProviderModel
@@ -124,14 +124,124 @@ def reset_encrypt_key_pair():
         'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))


-@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
-def create_qdrant_indexes():
-    """
-    Migrate other vector database datas to Qdrant.
-    """
-    click.echo(click.style('Start create qdrant indexes.', fg='green'))
-    create_count = 0
+@click.command('vdb-migrate', help='migrate vector db.')
+@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
+def vdb_migrate(scope: str):
+    if scope in ['knowledge', 'all']:
+        migrate_knowledge_vector_database()
+    if scope in ['annotation', 'all']:
+        migrate_annotation_vector_database()
+
+
+def migrate_annotation_vector_database():
+    """
+    Migrate annotation datas to target vector database .
+    """
+    click.echo(click.style('Start migrate annotation data.', fg='green'))
+    create_count = 0
+    skipped_count = 0
+    total_count = 0
+    page = 1
+    while True:
+        try:
+            # get apps info
+            apps = db.session.query(App).filter(
+                App.status == 'normal'
+            ).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for app in apps:
+            total_count = total_count + 1
+            click.echo(f'Processing the {total_count} app {app.id}. '
+                       + f'{create_count} created, {skipped_count} skipped.')
+            try:
+                click.echo('Create app annotation index: {}'.format(app.id))
+                app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
+                    AppAnnotationSetting.app_id == app.id
+                ).first()
+
+                if not app_annotation_setting:
+                    skipped_count = skipped_count + 1
+                    click.echo('App annotation setting is disabled: {}'.format(app.id))
+                    continue
+                # get dataset_collection_binding info
+                dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
+                    DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
+                ).first()
+                if not dataset_collection_binding:
+                    click.echo('App annotation collection binding is not exist: {}'.format(app.id))
+                    continue
+                annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
+                dataset = Dataset(
+                    id=app.id,
+                    tenant_id=app.tenant_id,
+                    indexing_technique='high_quality',
+                    embedding_model_provider=dataset_collection_binding.provider_name,
+                    embedding_model=dataset_collection_binding.model_name,
+                    collection_binding_id=dataset_collection_binding.id
+                )
+                documents = []
+                if annotations:
+                    for annotation in annotations:
+                        document = Document(
+                            page_content=annotation.question,
+                            metadata={
+                                "annotation_id": annotation.id,
+                                "app_id": app.id,
+                                "doc_id": annotation.id
+                            }
+                        )
+                        documents.append(document)
+
+                vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
+                click.echo(f"Start to migrate annotation, app_id: {app.id}.")
+
+                try:
+                    vector.delete()
+                    click.echo(
+                        click.style(f'Successfully delete vector index for app: {app.id}.',
+                                    fg='green'))
+                except Exception as e:
+                    click.echo(
+                        click.style(f'Failed to delete vector index for app {app.id}.',
+                                    fg='red'))
+                    raise e
+                if documents:
+                    try:
+                        click.echo(click.style(
+                            f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
+                            fg='green'))
+                        vector.create(documents)
+                        click.echo(
+                            click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
+                    except Exception as e:
+                        click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
+                        raise e
+                click.echo(f'Successfully migrated app annotation {app.id}.')
+                create_count += 1
+            except Exception as e:
+                click.echo(
+                    click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                fg='red'))
+                continue
+
+    click.echo(
+        click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
+                    fg='green'))
+
+
+def migrate_knowledge_vector_database():
+    """
+    Migrate vector database datas to target vector database .
+    """
+    click.echo(click.style('Start migrate vector db.', fg='green'))
+    create_count = 0
+    skipped_count = 0
+    total_count = 0
+    config = current_app.config
+    vector_type = config.get('VECTOR_STORE')
     page = 1
     while True:
         try:
@@ -140,60 +250,128 @@ def create_qdrant_indexes():
         except NotFound:
             break
-
-        model_manager = ModelManager()

         page += 1
         for dataset in datasets:
-            if dataset.index_struct_dict:
-                if dataset.index_struct_dict['type'] != 'qdrant':
-                    try:
-                        click.echo('Create dataset qdrant index: {}'.format(dataset.id))
-                        try:
-                            embedding_model = model_manager.get_model_instance(
-                                tenant_id=dataset.tenant_id,
-                                provider=dataset.embedding_model_provider,
-                                model_type=ModelType.TEXT_EMBEDDING,
-                                model=dataset.embedding_model
-
-                            )
-                        except Exception:
-                            continue
-                        embeddings = CacheEmbedding(embedding_model)
-
-                        from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
-
-                        index = QdrantVectorIndex(
-                            dataset=dataset,
-                            config=QdrantConfig(
-                                endpoint=current_app.config.get('QDRANT_URL'),
-                                api_key=current_app.config.get('QDRANT_API_KEY'),
-                                root_path=current_app.root_path
-                            ),
-                            embeddings=embeddings
-                        )
-                        if index:
-                            index.create_qdrant_dataset(dataset)
-                            index_struct = {
-                                "type": 'qdrant',
-                                "vector_store": {
-                                    "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
-                            }
-                            dataset.index_struct = json.dumps(index_struct)
-                            db.session.commit()
-                            create_count += 1
-                        else:
-                            click.echo('passed.')
-                    except Exception as e:
-                        click.echo(
-                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                        fg='red'))
-
-    click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
+            total_count = total_count + 1
+            click.echo(f'Processing the {total_count} dataset {dataset.id}. '
+                       + f'{create_count} created, ${skipped_count} skipped.')
+            try:
+                click.echo('Create dataset vdb index: {}'.format(dataset.id))
+                if dataset.index_struct_dict:
+                    if dataset.index_struct_dict['type'] == vector_type:
+                        skipped_count = skipped_count + 1
+                        continue
+                collection_name = ''
+                if vector_type == "weaviate":
+                    dataset_id = dataset.id
+                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                    index_struct_dict = {
+                        "type": 'weaviate',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == "qdrant":
+                    if dataset.collection_binding_id:
+                        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                            filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
+                            one_or_none()
+                        if dataset_collection_binding:
+                            collection_name = dataset_collection_binding.collection_name
+                        else:
+                            raise ValueError('Dataset Collection Bindings is not exist!')
+                    else:
+                        dataset_id = dataset.id
+                        collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                    index_struct_dict = {
+                        "type": 'qdrant',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == "milvus":
+                    dataset_id = dataset.id
+                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                    index_struct_dict = {
+                        "type": 'milvus',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                else:
+                    raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+
+                vector = Vector(dataset)
+                click.echo(f"Start to migrate dataset {dataset.id}.")
+
+                try:
+                    vector.delete()
+                    click.echo(
+                        click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
+                                    fg='green'))
+                except Exception as e:
+                    click.echo(
+                        click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
+                                    fg='red'))
+                    raise e
+
+                dataset_documents = db.session.query(DatasetDocument).filter(
+                    DatasetDocument.dataset_id == dataset.id,
+                    DatasetDocument.indexing_status == 'completed',
+                    DatasetDocument.enabled == True,
+                    DatasetDocument.archived == False,
+                ).all()
+
+                documents = []
+                segments_count = 0
+                for dataset_document in dataset_documents:
+                    segments = db.session.query(DocumentSegment).filter(
+                        DocumentSegment.document_id == dataset_document.id,
+                        DocumentSegment.status == 'completed',
+                        DocumentSegment.enabled == True
+                    ).all()
+
+                    for segment in segments:
+                        document = Document(
+                            page_content=segment.content,
+                            metadata={
+                                "doc_id": segment.index_node_id,
+                                "doc_hash": segment.index_node_hash,
+                                "document_id": segment.document_id,
+                                "dataset_id": segment.dataset_id,
+                            }
+                        )
+
+                        documents.append(document)
+                        segments_count = segments_count + 1
+
+                if documents:
+                    try:
+                        click.echo(click.style(
+                            f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
+                            fg='green'))
+                        vector.create(documents)
+                        click.echo(
+                            click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
+                    except Exception as e:
+                        click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
+                        raise e
+                db.session.add(dataset)
+                db.session.commit()
+                click.echo(f'Successfully migrated dataset {dataset.id}.')
+                create_count += 1
+            except Exception as e:
+                db.session.rollback()
+                click.echo(
+                    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                fg='red'))
+                continue
+
+    click.echo(
+        click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
+                    fg='green'))


 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(reset_encrypt_key_pair)
-    app.cli.add_command(create_qdrant_indexes)
+    app.cli.add_command(vdb_migrate)
```
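Once registered, the command is exposed through the Flask CLI. Based on the `--scope` option above, `flask vdb-migrate` migrates both scopes, while `flask vdb-migrate --scope knowledge` or `flask vdb-migrate --scope annotation` restricts the run to dataset documents or app annotations respectively.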
```diff
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import os

 import dotenv
@@ -39,7 +38,9 @@ DEFAULTS = {
     'LOG_LEVEL': 'INFO',
     'HOSTED_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
+    'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
     'HOSTED_OPENAI_PAID_ENABLED': 'False',
+    'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
     'HOSTED_AZURE_OPENAI_ENABLED': 'False',
     'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
@@ -57,6 +58,8 @@ DEFAULTS = {
     'BILLING_ENABLED': 'False',
     'CAN_REPLACE_LOGO': 'False',
     'ETL_TYPE': 'dify',
+    'KEYWORD_STORE': 'jieba',
+    'BATCH_UPLOAD_LIMIT': 20
 }
@@ -87,7 +90,7 @@ class Config:
         # ------------------------
         # General Configurations.
         # ------------------------
-        self.CURRENT_VERSION = "0.5.4"
+        self.CURRENT_VERSION = "0.5.9"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -183,7 +186,7 @@ class Config:
         # Currently, only support: qdrant, milvus, zilliz, weaviate
         # ------------------------
         self.VECTOR_STORE = get_env('VECTOR_STORE')
-
+        self.KEYWORD_STORE = get_env('KEYWORD_STORE')
         # qdrant settings
         self.QDRANT_URL = get_env('QDRANT_URL')
         self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
@@ -209,6 +212,12 @@ class Config:
         self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
         self.RESEND_API_KEY = get_env('RESEND_API_KEY')
         self.RESEND_API_URL = get_env('RESEND_API_URL')
+        # SMTP settings
+        self.SMTP_SERVER = get_env('SMTP_SERVER')
+        self.SMTP_PORT = get_env('SMTP_PORT')
+        self.SMTP_USERNAME = get_env('SMTP_USERNAME')
+        self.SMTP_PASSWORD = get_env('SMTP_PASSWORD')
+        self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS')

         # ------------------------
         # Workpace Configurations.
@@ -254,8 +263,10 @@ class Config:
         self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
         self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
         self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
+        self.HOSTED_OPENAI_TRIAL_MODELS = get_env('HOSTED_OPENAI_TRIAL_MODELS')
         self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
         self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
+        self.HOSTED_OPENAI_PAID_MODELS = get_env('HOSTED_OPENAI_PAID_MODELS')

         self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
         self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
@@ -280,6 +291,10 @@ class Config:
         self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
         self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')

+        self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
+
+        self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')


 class CloudEditionConfig(Config):
```
```diff
@@ -1,9 +1,8 @@
 import json

 from models.model import AppModelConfig

-languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT']
+languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA']

 language_timezone_mapping = {
     'en-US': 'America/New_York',
```
```diff
@@ -16,8 +15,10 @@ language_timezone_mapping = {
     'ko-KR': 'Asia/Seoul',
     'ru-RU': 'Europe/Moscow',
     'it-IT': 'Europe/Rome',
+    'uk-UA': 'Europe/Kyiv',
 }


 def supported_language(lang):
     if lang in languages:
         return lang
```
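With 'uk-UA' added to `languages`, a call such as `supported_language('uk-UA')` now returns the code instead of raising a `ValueError`.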
```diff
@@ -26,6 +27,7 @@ def supported_language(lang):
                  .format(lang=lang))
     raise ValueError(error)


 user_input_form_template = {
     "en-US": [
         {
```
```diff
@@ -67,6 +69,16 @@
             }
         }
     ],
+    "ua-UK": [
+        {
+            "paragraph": {
+                "label": "Запит",
+                "variable": "default_input",
+                "required": False,
+                "default": ""
+            }
+        }
+    ],
 }

 demo_model_templates = {
```
```diff
@@ -145,7 +157,7 @@
                     'Italian',
                 ]
             }
-        },{
+        }, {
             "paragraph": {
                 "label": "Query",
                 "variable": "query",
```
```diff
@@ -272,7 +284,7 @@
                     "意大利语",
                 ]
             }
-        },{
+        }, {
             "paragraph": {
                 "label": "文本内容",
                 "variable": "query",
```
```diff
@@ -323,5 +335,130 @@
         )
     }
 ],
+'uk-UA': [{
+    "name": "Помічник перекладу",
+    "icon": "",
+    "icon_background": "",
+    "description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.",
+    "mode": "completion",
+    "model_config": AppModelConfig(
+        provider="openai",
+        model_id="gpt-3.5-turbo-instruct",
+        configs={
+            "prompt_template": "Будь ласка, перекладіть наступний текст на {{target_language}}:\n",
+            "prompt_variables": [
+                {
+                    "key": "target_language",
+                    "name": "Цільова мова",
+                    "description": "Мова, на яку ви хочете перекласти.",
+                    "type": "select",
+                    "default": "Ukrainian",
+                    "options": [
+                        "Chinese",
+                        "English",
+                        "Japanese",
+                        "French",
+                        "Russian",
+                        "German",
+                        "Spanish",
+                        "Korean",
+                        "Italian",
+                    ],
+                },
+            ],
+            "completion_params": {
+                "max_token": 1000,
+                "temperature": 0,
+                "top_p": 0,
+                "presence_penalty": 0.1,
+                "frequency_penalty": 0.1,
+            },
+        },
+        opening_statement="",
+        suggested_questions=None,
+        pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:",
+        model=json.dumps({
+            "provider": "openai",
+            "name": "gpt-3.5-turbo-instruct",
+            "mode": "completion",
+            "completion_params": {
+                "max_tokens": 1000,
+                "temperature": 0,
+                "top_p": 0,
+                "presence_penalty": 0.1,
+                "frequency_penalty": 0.1,
+            },
+        }),
+        user_input_form=json.dumps([
+            {
+                "select": {
+                    "label": "Цільова мова",
+                    "variable": "target_language",
+                    "description": "Мова, на яку ви хочете перекласти.",
+                    "default": "Chinese",
+                    "required": True,
+                    'options': [
+                        'Chinese',
+                        'English',
+                        'Japanese',
+                        'French',
+                        'Russian',
+                        'German',
+                        'Spanish',
+                        'Korean',
+                        'Italian',
+                    ]
+                }
+            }, {
+                "paragraph": {
+                    "label": "Запит",
+                    "variable": "query",
+                    "required": True,
+                    "default": ""
+                }
+            }
+        ])
+    )
+},
+{
+    "name": "AI інтерв’юер фронтенду",
+    "icon": "",
+    "icon_background": "",
+    "description": "Симульований інтерв’юер фронтенду, який перевіряє рівень кваліфікації у розробці фронтенду через опитування.",
+    "mode": "chat",
+    "model_config": AppModelConfig(
+        provider="openai",
+        model_id="gpt-3.5-turbo",
+        configs={
+            "introduction": "Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
+            "prompt_template": "Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
+            "prompt_variables": [],
+            "completion_params": {
+                "max_token": 300,
+                "temperature": 0.8,
+                "top_p": 0.9,
+                "presence_penalty": 0.1,
+                "frequency_penalty": 0.1,
+            },
+        },
+        opening_statement="Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
+        suggested_questions=None,
+        pre_prompt="Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
+        model=json.dumps({
+            "provider": "openai",
+            "name": "gpt-3.5-turbo",
+            "mode": "chat",
+            "completion_params": {
+                "max_tokens": 300,
+                "temperature": 0.8,
+                "top_p": 0.9,
+                "presence_penalty": 0.1,
+                "frequency_penalty": 0.1,
+            },
+        }),
+        user_input_form=None
+    ),
+}
+],

 }
```
```diff
@@ -13,30 +13,14 @@ model_templates = {
         'status': 'normal'
     },
     'model_config': {
-        'provider': 'openai',
-        'model_id': 'gpt-3.5-turbo-instruct',
-        'configs': {
-            'prompt_template': '',
-            'prompt_variables': [],
-            'completion_params': {
-                'max_token': 512,
-                'temperature': 1,
-                'top_p': 1,
-                'presence_penalty': 0,
-                'frequency_penalty': 0,
-            }
-        },
+        'provider': '',
+        'model_id': '',
+        'configs': {},
         'model': json.dumps({
             "provider": "openai",
             "name": "gpt-3.5-turbo-instruct",
             "mode": "completion",
-            "completion_params": {
-                "max_tokens": 512,
-                "temperature": 1,
-                "top_p": 1,
-                "presence_penalty": 0,
-                "frequency_penalty": 0
-            }
+            "completion_params": {}
         }),
         'user_input_form': json.dumps([
             {
@@ -64,30 +48,14 @@
         'status': 'normal'
     },
     'model_config': {
-        'provider': 'openai',
-        'model_id': 'gpt-3.5-turbo',
-        'configs': {
-            'prompt_template': '',
-            'prompt_variables': [],
-            'completion_params': {
-                'max_token': 512,
-                'temperature': 1,
-                'top_p': 1,
-                'presence_penalty': 0,
-                'frequency_penalty': 0,
-            }
-        },
+        'provider': '',
+        'model_id': '',
+        'configs': {},
         'model': json.dumps({
             "provider": "openai",
             "name": "gpt-3.5-turbo",
             "mode": "chat",
-            "completion_params": {
-                "max_tokens": 512,
-                "temperature": 1,
-                "top_p": 1,
-                "presence_penalty": 0,
-                "frequency_penalty": 0
-            }
+            "completion_params": {}
         })
     }
 },
```
```diff
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
 from datetime import datetime
@@ -28,7 +27,9 @@ from fields.app_fields import (
 from libs.login import login_required
 from models.model import App, AppModelConfig, Site
 from services.app_model_config_service import AppModelConfigService
+from core.tools.utils.configuration import ToolParameterConfigurationManager
+from core.tools.tool_manager import ToolManager
+from core.entities.application_entities import AgentToolEntity

 def _get_app(app_id, tenant_id):
     app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
@@ -125,19 +126,13 @@
             available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
             provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
             if provider_model not in available_models_names:
-                model_manager = ModelManager()
-                model_instance = model_manager.get_default_model_instance(
-                    tenant_id=current_user.current_tenant_id,
-                    model_type=ModelType.LLM
-                )
-
-                if not model_instance:
+                if not default_model_entity:
                     raise ProviderNotInitializeError(
-                        f"No Default System Reasoning Model available. Please configure "
-                        f"in the Settings -> Model Provider.")
+                        "No Default System Reasoning Model available. Please configure "
+                        "in the Settings -> Model Provider.")
                 else:
-                    model_config_dict["model"]["provider"] = model_instance.provider
-                    model_config_dict["model"]["name"] = model_instance.model
+                    model_config_dict["model"]["provider"] = default_model_entity.provider.provider
+                    model_config_dict["model"]["name"] = default_model_entity.model

             model_configuration = AppModelConfigService.validate_configuration(
                 tenant_id=current_user.current_tenant_id,
@@ -243,7 +238,42 @@ class AppApi(Resource):
     def get(self, app_id):
         """Get app detail"""
         app_id = str(app_id)
-        app = _get_app(app_id, current_user.current_tenant_id)
+        app: App = _get_app(app_id, current_user.current_tenant_id)
+
+        # get original app model config
+        model_config: AppModelConfig = app.app_model_config
+        agent_mode = model_config.agent_mode_dict
+        # decrypt agent tool parameters if it's secret-input
+        for tool in agent_mode.get('tools') or []:
+            agent_tool_entity = AgentToolEntity(**tool)
+            # get tool
+            try:
+                tool_runtime = ToolManager.get_agent_tool_runtime(
+                    tenant_id=current_user.current_tenant_id,
+                    agent_tool=agent_tool_entity,
+                    agent_callback=None
+                )
+                manager = ToolParameterConfigurationManager(
+                    tenant_id=current_user.current_tenant_id,
+                    tool_runtime=tool_runtime,
+                    provider_name=agent_tool_entity.provider_id,
+                    provider_type=agent_tool_entity.provider_type,
+                )
+
+                # get decrypted parameters
+                if agent_tool_entity.tool_parameters:
+                    parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
+                    masked_parameter = manager.mask_tool_parameters(parameters or {})
+                else:
+                    masked_parameter = {}
+
+                # override tool parameters
+                tool['tool_parameters'] = masked_parameter
+            except Exception as e:
+                pass
+
+        # override agent mode
+        model_config.agent_mode = json.dumps(agent_mode)

         return app
```
```diff
@@ -1,8 +1,7 @@
-# -*- coding:utf-8 -*-
 import logging

 from flask import request
-from flask_restful import Resource
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import InternalServerError

 import services
@@ -46,7 +45,8 @@ class ChatMessageAudioApi(Resource):
         try:
             response = AudioService.transcript_asr(
                 tenant_id=app_model.tenant_id,
-                file=file
+                file=file,
+                end_user=None,
             )

             return response
@@ -72,7 +72,7 @@ class ChatMessageAudioApi(Resource):
         except ValueError as e:
             raise e
         except Exception as e:
-            logging.exception("internal server error.")
+            logging.exception(f"internal server error, {str(e)}.")
             raise InternalServerError()


@@ -83,10 +83,12 @@ class ChatMessageTextApi(Resource):
     def post(self, app_id):
         app_id = str(app_id)
         app_model = _get_app(app_id, None)

         try:
             response = AudioService.transcript_tts(
                 tenant_id=app_model.tenant_id,
                 text=request.form['text'],
+                voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
+                streaming=False
             )
@@ -113,9 +115,50 @@ class ChatMessageTextApi(Resource):
         except ValueError as e:
             raise e
         except Exception as e:
-            logging.exception("internal server error.")
+            logging.exception(f"internal server error, {str(e)}.")
             raise InternalServerError()


+class TextModesApi(Resource):
+    def get(self, app_id: str):
+        app_model = _get_app(str(app_id))
+
+        try:
+            parser = reqparse.RequestParser()
+            parser.add_argument('language', type=str, required=True, location='args')
+            args = parser.parse_args()
+
+            response = AudioService.transcript_tts_voices(
+                tenant_id=app_model.tenant_id,
+                language=args['language'],
+            )
+
+            return response
+        except services.errors.audio.ProviderNotSupportTextToSpeechLanageServiceError:
+            raise AppUnavailableError("Text to audio voices language parameter loss.")
+        except NoAudioUploadedServiceError:
+            raise NoAudioUploadedError()
+        except AudioTooLargeServiceError as e:
+            raise AudioTooLargeError(str(e))
+        except UnsupportedAudioTypeServiceError:
+            raise UnsupportedAudioTypeError()
+        except ProviderNotSupportSpeechToTextServiceError:
+            raise ProviderNotSupportSpeechToTextError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+        except ValueError as e:
+            raise e
+        except Exception as e:
+            logging.exception(f"internal server error, {str(e)}.")
+            raise InternalServerError()


 api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
 api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
+api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
```
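Given the route and the required `language` query argument, the new voices endpoint is presumably called as `GET /apps/<app_id>/text-to-audio/voices?language=en-US`, where the app ID and the language value are placeholders.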
```diff
@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union

 import flask_login
 from flask import Response, stream_with_context
@@ -169,8 +169,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response

         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')
```
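The `yield from` rewrite is behavior-preserving; a minimal standalone sketch of the equivalence:

```python
# `yield from` delegates iteration to the inner iterable, replacing the
# explicit loop-and-yield; both generators below produce the same items.
def explicit(chunks):
    for chunk in chunks:
        yield chunk

def delegated(chunks):
    yield from chunks

assert list(explicit(["a", "b"])) == list(delegated(["a", "b"]))
```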
```diff
@@ -1,6 +1,7 @@
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union

 from flask import Response, stream_with_context
 from flask_login import current_user
@@ -246,8 +247,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response

         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')
```
```diff
@@ -1,4 +1,4 @@
-# -*- coding:utf-8 -*-
 import json

 from flask import request
+from flask_login import current_user
@@ -8,6 +8,9 @@ from controllers.console import api
 from controllers.console.app import _get_app
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from core.entities.application_entities import AgentToolEntity
+from core.tools.tool_manager import ToolManager
+from core.tools.utils.configuration import ToolParameterConfigurationManager
 from events.app_event import app_model_config_was_updated
 from extensions.ext_database import db
 from libs.login import login_required
@@ -39,6 +42,88 @@ class ModelConfigResource(Resource):
         )
         new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)

+        # get original app model config
+        original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
+            AppModelConfig.id == app.app_model_config_id
+        ).first()
+        agent_mode = original_app_model_config.agent_mode_dict
+        # decrypt agent tool parameters if it's secret-input
+        parameter_map = {}
+        masked_parameter_map = {}
+        tool_map = {}
+        for tool in agent_mode.get('tools') or []:
+            agent_tool_entity = AgentToolEntity(**tool)
+            # get tool
+            try:
+                tool_runtime = ToolManager.get_agent_tool_runtime(
+                    tenant_id=current_user.current_tenant_id,
+                    agent_tool=agent_tool_entity,
+                    agent_callback=None
+                )
+                manager = ToolParameterConfigurationManager(
+                    tenant_id=current_user.current_tenant_id,
+                    tool_runtime=tool_runtime,
+                    provider_name=agent_tool_entity.provider_id,
+                    provider_type=agent_tool_entity.provider_type,
+                )
+            except Exception as e:
+                continue
+
+            # get decrypted parameters
+            if agent_tool_entity.tool_parameters:
+                parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
+                masked_parameter = manager.mask_tool_parameters(parameters or {})
+            else:
+                parameters = {}
+                masked_parameter = {}
+
+            key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
+            masked_parameter_map[key] = masked_parameter
+            parameter_map[key] = parameters
+            tool_map[key] = tool_runtime
+
+        # encrypt agent tool parameters if it's secret-input
+        agent_mode = new_app_model_config.agent_mode_dict
+        for tool in agent_mode.get('tools') or []:
+            agent_tool_entity = AgentToolEntity(**tool)
+
+            # get tool
+            key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
+            if key in tool_map:
+                tool_runtime = tool_map[key]
+            else:
+                try:
+                    tool_runtime = ToolManager.get_agent_tool_runtime(
+                        tenant_id=current_user.current_tenant_id,
+                        agent_tool=agent_tool_entity,
+                        agent_callback=None
+                    )
+                except Exception as e:
+                    continue
+
+            manager = ToolParameterConfigurationManager(
+                tenant_id=current_user.current_tenant_id,
+                tool_runtime=tool_runtime,
+                provider_name=agent_tool_entity.provider_id,
+                provider_type=agent_tool_entity.provider_type,
+            )
+            manager.delete_tool_parameters_cache()
+
+            # override parameters if it equals to masked parameters
+            if agent_tool_entity.tool_parameters:
+                if key not in masked_parameter_map:
+                    continue
+
+                if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
+                    agent_tool_entity.tool_parameters = parameter_map[key]
+
+            # encrypt parameters
+            if agent_tool_entity.tool_parameters:
+                tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
+
+        # update app model config
+        new_app_model_config.agent_mode = json.dumps(agent_mode)

         db.session.add(new_app_model_config)
         db.session.flush()
```
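The two loops implement a mask-compare-restore scheme for secret-input tool parameters: values are masked before leaving the server, and a parameter that comes back still equal to its masked form is swapped back to the stored plaintext before re-encryption. A minimal sketch of that comparison, with hypothetical names rather than Dify's API:

```python
# Hypothetical helper: keep the stored plaintext when the client echoed the
# mask back unchanged; otherwise accept the newly submitted value.
def restore_if_masked(submitted: dict, masked: dict, stored: dict) -> dict:
    return stored if submitted == masked else submitted

stored = {"api_key": "sk-1234567890"}
masked = {"api_key": "sk-******7890"}

assert restore_if_masked(masked, masked, stored) == stored  # untouched in the UI
assert restore_if_masked({"api_key": "sk-new"}, masked, stored) == {"api_key": "sk-new"}  # edited
```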
```diff
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, NotFound
```
```diff
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
 from decimal import Decimal
```
```diff
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import flask_login
 from flask import current_app, request
 from flask_restful import Resource, reqparse
```
```diff
@@ -8,7 +7,7 @@ from controllers.console import api
 from controllers.console.setup import setup_required
 from libs.helper import email
 from libs.password import valid_password
-from services.account_service import AccountService
+from services.account_service import AccountService, TenantService


 class LoginApi(Resource):
@@ -30,6 +29,8 @@ class LoginApi(Resource):
         except services.errors.account.AccountLoginError:
             return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401

+        TenantService.create_owner_tenant_if_not_exist(account)
+
         AccountService.update_last_login(account, request)

         # todo: return the user info
```
```diff
@@ -10,7 +10,7 @@ from constants.languages import languages
 from extensions.ext_database import db
 from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
 from models.account import Account, AccountStatus
-from services.account_service import AccountService, RegisterService
+from services.account_service import AccountService, RegisterService, TenantService

 from .. import api
@@ -76,6 +76,8 @@ class OAuthCallback(Resource):
             account.initialized_at = datetime.utcnow()
             db.session.commit()

+        TenantService.create_owner_tenant_if_not_exist(account)
+
         AccountService.update_last_login(account, request)

         token = AccountService.get_account_jwt_token(account)
```
@@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.data_loader.loader.notion import NotionLoader
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||
from libs.login import login_required
|
||||
@@ -173,14 +174,15 @@ class DataSourceNotionApi(Resource):
|
||||
if not data_source_binding:
|
||||
raise NotFound('Data source binding not found.')
|
||||
|
||||
loader = NotionLoader(
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
|
||||
text_docs = loader.load()
|
||||
text_docs = extractor.extract()
|
||||
return {
|
||||
'content': "\n".join([doc.page_content for doc in text_docs])
|
||||
}, 200
|
||||
@@ -192,11 +194,31 @@ class DataSourceNotionApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
notion_info_list = args['notion_info_list']
|
||||
extract_settings = []
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
for page in notion_info['pages']:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page['page_id'],
|
||||
"notion_page_type": page['type'],
|
||||
"tenant_id": current_user.current_tenant_id
|
||||
},
|
||||
document_model=args['doc_form']
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
|
||||
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'])
|
||||
return response, 200
|
||||
|
||||
|
||||
|
||||
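The new estimate path above builds one ExtractSetting per Notion page and hands the whole list to a single indexing_estimate call, replacing the Notion-specific notion_indexing_estimate. A hedged sketch of that calling convention, using only the signatures visible in this diff (tenant_id, pages, and process_rule are stand-in values):

# Sketch of the unified estimate call; values are placeholders.
extract_settings = [
    ExtractSetting(
        datasource_type="notion_import",
        notion_info={
            "notion_workspace_id": workspace_id,
            "notion_obj_id": page['page_id'],
            "notion_page_type": page['type'],
            "tenant_id": tenant_id,
        },
        document_model='text_model',
    )
    for page in pages
]
response = IndexingRunner().indexing_estimate(
    tenant_id, extract_settings, process_rule, 'text_model', 'English')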
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import flask_restful
 from flask import current_app, request
 from flask_login import current_user
@@ -16,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
+from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@@ -179,9 +179,9 @@ class DatasetApi(Resource):
                             location='json', store_missing=False,
                             type=_validate_description_length)
         parser.add_argument('indexing_technique', type=str, location='json',
                             choices=Dataset.INDEXING_TECHNIQUE_LIST,
                             nullable=True,
                             help='Invalid indexing technique.')
         parser.add_argument('permission', type=str, location='json', choices=(
             'only_me', 'all_team_members'), help='Invalid permission.')
         parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
@@ -259,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('indexing_technique', type=str, required=True,
                             choices=Dataset.INDEXING_TECHNIQUE_LIST,
                             nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
@@ -269,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource):
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
+        extract_settings = []
         if args['info_list']['data_source_type'] == 'upload_file':
             file_ids = args['info_list']['file_info_list']['file_ids']
             file_details = db.session.query(UploadFile).filter(
@@ -279,37 +280,45 @@ class DatasetIndexingEstimateApi(Resource):
             if file_details is None:
                 raise NotFound("File not found.")

-            indexing_runner = IndexingRunner()
-
-            try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  args['process_rule'], args['doc_form'],
-                                                                  args['doc_language'], args['dataset_id'],
-                                                                  args['indexing_technique'])
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    f"No Embedding Model available. Please configure a valid provider "
-                    f"in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
+            if file_details:
+                for file_detail in file_details:
+                    extract_setting = ExtractSetting(
+                        datasource_type="upload_file",
+                        upload_file=file_detail,
+                        document_model=args['doc_form']
+                    )
+                    extract_settings.append(extract_setting)
         elif args['info_list']['data_source_type'] == 'notion_import':
-
-            indexing_runner = IndexingRunner()
-
-            try:
-                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                                    args['info_list']['notion_info_list'],
-                                                                    args['process_rule'], args['doc_form'],
-                                                                    args['doc_language'], args['dataset_id'],
-                                                                    args['indexing_technique'])
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    f"No Embedding Model available. Please configure a valid provider "
-                    f"in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
+            notion_info_list = args['info_list']['notion_info_list']
+            for notion_info in notion_info_list:
+                workspace_id = notion_info['workspace_id']
+                for page in notion_info['pages']:
+                    extract_setting = ExtractSetting(
+                        datasource_type="notion_import",
+                        notion_info={
+                            "notion_workspace_id": workspace_id,
+                            "notion_obj_id": page['page_id'],
+                            "notion_page_type": page['type'],
+                            "tenant_id": current_user.current_tenant_id
+                        },
+                        document_model=args['doc_form']
+                    )
+                    extract_settings.append(extract_setting)
+        else:
+            raise ValueError('Data source type not support')
+        indexing_runner = IndexingRunner()
+        try:
+            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                         args['process_rule'], args['doc_form'],
+                                                         args['doc_language'], args['dataset_id'],
+                                                         args['indexing_technique'])
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)

         return response, 200

@@ -509,4 +518,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
 api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
 api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
 api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
@@ -1,6 +1,4 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
-from typing import List

 from flask import request
 from flask_login import current_user
@@ -34,6 +32,7 @@ from core.indexing_runner import IndexingRunner
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
+from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.document_fields import (
@@ -71,7 +70,7 @@ class DocumentResource(Resource):

         return document

-    def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
+    def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
@@ -97,7 +96,7 @@ class GetProcessRuleApi(Resource):
         req_data = request.args

         document_id = req_data.get('document_id')

         # get default rules
         mode = DocumentService.DEFAULT_RULES['mode']
         rules = DocumentService.DEFAULT_RULES['rules']
@@ -296,8 +295,8 @@ class DatasetInitApi(Resource):
             )
         except InvokeAuthorizationError:
             raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)

@@ -364,16 +363,22 @@ class DocumentIndexingEstimateApi(DocumentResource):
             if not file:
                 raise NotFound('File not found.')

+            extract_setting = ExtractSetting(
+                datasource_type="upload_file",
+                upload_file=file,
+                document_model=document.doc_form
+            )
+
             indexing_runner = IndexingRunner()

             try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
-                                                                  data_process_rule_dict, None,
-                                                                  'English', dataset_id)
+                response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
+                                                             data_process_rule_dict, document.doc_form,
+                                                             'English', dataset_id)
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
-                    f"No Embedding Model available. Please configure a valid provider "
-                    f"in the Settings -> Model Provider.")
+                    "No Embedding Model available. Please configure a valid provider "
+                    "in the Settings -> Model Provider.")
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)

@@ -404,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         data_process_rule = documents[0].dataset_process_rule
         data_process_rule_dict = data_process_rule.to_dict()
         info_list = []
+        extract_settings = []
         for document in documents:
             if document.indexing_status in ['completed', 'error']:
                 raise DocumentAlreadyFinishedError()
@@ -426,42 +432,49 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 }
                 info_list.append(notion_info)

-        if dataset.data_source_type == 'upload_file':
-            file_details = db.session.query(UploadFile).filter(
-                UploadFile.tenant_id == current_user.current_tenant_id,
-                UploadFile.id.in_(info_list)
-            ).all()
-
-            if file_details is None:
-                raise NotFound("File not found.")
-
-            indexing_runner = IndexingRunner()
-            try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  data_process_rule_dict, None,
-                                                                  'English', dataset_id)
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    f"No Embedding Model available. Please configure a valid provider "
-                    f"in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
-        elif dataset.data_source_type == 'notion_import':
-
-            indexing_runner = IndexingRunner()
-            try:
-                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                                    info_list,
-                                                                    data_process_rule_dict,
-                                                                    None, 'English', dataset_id)
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    f"No Embedding Model available. Please configure a valid provider "
-                    f"in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
-        else:
-            raise ValueError('Data source type not support')
+            if document.data_source_type == 'upload_file':
+                file_id = data_source_info['upload_file_id']
+                file_detail = db.session.query(UploadFile).filter(
+                    UploadFile.tenant_id == current_user.current_tenant_id,
+                    UploadFile.id == file_id
+                ).first()
+
+                if file_detail is None:
+                    raise NotFound("File not found.")
+
+                extract_setting = ExtractSetting(
+                    datasource_type="upload_file",
+                    upload_file=file_detail,
+                    document_model=document.doc_form
+                )
+                extract_settings.append(extract_setting)
+
+            elif document.data_source_type == 'notion_import':
+                extract_setting = ExtractSetting(
+                    datasource_type="notion_import",
+                    notion_info={
+                        "notion_workspace_id": data_source_info['notion_workspace_id'],
+                        "notion_obj_id": data_source_info['notion_page_id'],
+                        "notion_page_type": data_source_info['type'],
+                        "tenant_id": current_user.current_tenant_id
+                    },
+                    document_model=document.doc_form
+                )
+                extract_settings.append(extract_setting)
+
+            else:
+                raise ValueError('Data source type not support')
+        indexing_runner = IndexingRunner()
+        try:
+            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                         data_process_rule_dict, document.doc_form,
+                                                         'English', dataset_id)
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         return response
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import uuid
 from datetime import datetime

@@ -143,8 +142,8 @@ class DatasetDocumentSegmentApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)

@@ -234,8 +233,8 @@ class DatasetDocumentSegmentAddApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         try:
@@ -286,8 +285,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         # check segment
@@ -11,7 +11,7 @@ from controllers.console.datasets.error import (
     UnsupportedFileTypeError,
 )
 from controllers.console.setup import setup_required
-from controllers.console.wraps import account_initialization_required
+from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from fields.file_fields import file_fields, upload_config_fields
 from libs.login import login_required
 from services.file_service import ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS, FileService
@@ -39,6 +39,7 @@ class FileApi(Resource):
     @login_required
     @account_initialization_required
     @marshal_with(file_fields)
+    @cloud_edition_billing_resource_check(resource='documents')
     def post(self):

         # get file from request

@@ -76,8 +76,8 @@ class HitTestingApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                f"No Embedding Model or Reranking Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+                "No Embedding Model or Reranking Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
         except InvokeError as e:
             raise CompletionRequestError(e.description)
         except ValueError as e:
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging

 from flask import request
@@ -86,6 +85,7 @@ class ChatTextApi(InstalledAppResource):
         response = AudioService.transcript_tts(
             tenant_id=app_model.tenant_id,
             text=request.form['text'],
+            voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
             streaming=False
         )
         return {'data': response.data.decode('latin1')}

@@ -1,8 +1,8 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
+from collections.abc import Generator
 from datetime import datetime
-from typing import Generator, Union
+from typing import Union

 from flask import Response, stream_with_context
 from flask_login import current_user
@@ -164,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response

         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')
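The compact_response helpers in these hunks now delegate with yield from rather than looping. For plain iteration the two forms produce identical output, but yield from is shorter and also forwards send(), throw(), and close() to the inner iterable. A self-contained illustration:

def chunks():
    yield 'a'
    yield 'b'

def via_loop():
    # Old style: explicit re-yield of every chunk.
    for chunk in chunks():
        yield chunk

def via_delegation():
    # New style: equivalent output, plus generator-protocol delegation.
    yield from chunks()

assert list(via_loop()) == list(via_delegation()) == ['a', 'b']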
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_login import current_user
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from libs.exception import BaseHTTPException

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime

 from flask_login import current_user

@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union

 from flask import Response, stream_with_context
 from flask_login import current_user
@@ -123,8 +123,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response

         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json

 from flask import current_app
@@ -78,7 +77,7 @@ class ExploreAppMetaApi(InstalledAppResource):
         # get all tools
         tools = agent_config.get('tools', [])
         url_prefix = (current_app.config.get("CONSOLE_API_URL")
-                      + f"/console/api/workspaces/current/tool-provider/builtin/")
+                      + "/console/api/workspaces/current/tool-provider/builtin/")
         for tool in tools:
             keys = list(tool.keys())
             if len(keys) >= 4:

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_login import current_user
 from flask_restful import Resource, fields, marshal_with
 from sqlalchemy import and_

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from functools import wraps

 from flask import current_app, request

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-

 import json
 import logging

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime

 import pytz

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask import current_app
 from flask_login import current_user
 from flask_restful import Resource, abort, fields, marshal_with, reqparse
@@ -12,6 +11,7 @@ from libs.helper import TimestampField
 from libs.login import login_required
 from models.account import Account
 from services.account_service import RegisterService, TenantService
+from services.errors.account import AccountAlreadyInTenantError

 account_fields = {
     'id': fields.String,
@@ -72,6 +72,13 @@ class MemberInviteEmailApi(Resource):
                     'email': invitee_email,
                     'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
                 })
+            except AccountAlreadyInTenantError:
+                invitation_results.append({
+                    'status': 'success',
+                    'email': invitee_email,
+                    'url': f'{console_web_url}/signin'
+                })
+                break
             except Exception as e:
                 invitation_results.append({
                     'status': 'failed',
@@ -82,6 +82,30 @@ class ToolBuiltinProviderIconApi(Resource):
         icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider)
         return send_file(io.BytesIO(icon_bytes), mimetype=minetype)

+class ToolModelProviderIconApi(Resource):
+    @setup_required
+    def get(self, provider):
+        icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider)
+        return send_file(io.BytesIO(icon_bytes), mimetype=mimetype)
+
+class ToolModelProviderListToolsApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        user_id = current_user.id
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
+
+        args = parser.parse_args()
+
+        return ToolManageService.list_model_tool_provider_tools(
+            user_id,
+            tenant_id,
+            args['provider'],
+        )
+
 class ToolApiProviderAddApi(Resource):
     @setup_required
@@ -259,6 +283,7 @@ class ToolApiProviderPreviousTestApi(Resource):
         parser = reqparse.RequestParser()

         parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json')
         parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
         parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
         parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
@@ -268,6 +293,7 @@ class ToolApiProviderPreviousTestApi(Resource):

         return ToolManageService.test_api_tool_preview(
             current_user.current_tenant_id,
+            args['provider_name'] if args['provider_name'] else '',
             args['tool_name'],
             args['credentials'],
             args['parameters'],
@@ -281,6 +307,8 @@ api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provide
 api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
 api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
 api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
+api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon')
+api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools')
 api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
 api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
 api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
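The two routes registered above expose a model provider's icon and its tool listing; the listing endpoint requires a provider query argument per its RequestParser. A hedged sketch of how a client might call it (the host, token, and provider name are placeholders, and the console normally authenticates through a logged-in session rather than a bearer token):

# Hypothetical client call; host, token, and provider are illustrative.
import requests

resp = requests.get(
    'https://dify.example.com/console/api/workspaces/current/tool-provider/model/tools',
    params={'provider': 'some-model-provider'},  # required by the parser above
    headers={'Authorization': 'Bearer <console-token>'},
)
resp.raise_for_status()
print(resp.json())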
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging

 from flask import request

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 from functools import wraps

@@ -57,6 +56,7 @@ def cloud_edition_billing_resource_check(resource: str,
             members = features.members
             apps = features.apps
             vector_space = features.vector_space
+            documents_upload_quota = features.documents_upload_quota
             annotation_quota_limit = features.annotation_quota_limit

             if resource == 'members' and 0 < members.limit <= members.size:
@@ -65,6 +65,13 @@ def cloud_edition_billing_resource_check(resource: str,
                 abort(403, error_msg)
             elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
                 abort(403, error_msg)
+            elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
+                # The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
+                source = request.args.get('source')
+                if source == 'datasets':
+                    abort(403, error_msg)
+                else:
+                    return view(*args, **kwargs)
             elif resource == 'workspace_custom' and not features.can_replace_logo:
                 abort(403, error_msg)
             elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
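The new documents branch reuses the quota guard applied to the other resources: a limit of 0 is treated as unlimited, and a request is rejected only once usage has reached the limit; for uploads the rejection additionally applies only when the request originates from the datasets UI (source=datasets). A minimal sketch of the guard predicate, with illustrative names (Quota stands in for the billing feature object):

# Illustrative quota guard; not the actual FeatureService types.
from dataclasses import dataclass

@dataclass
class Quota:
    limit: int  # 0 means unlimited
    size: int   # current usage

def quota_exhausted(q: Quota) -> bool:
    return 0 < q.limit <= q.size

assert not quota_exhausted(Quota(limit=0, size=999))  # unlimited plan
assert not quota_exhausted(Quota(limit=50, size=49))  # room left
assert quota_exhausted(Quota(limit=50, size=50))      # at the cap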
@@ -41,7 +41,7 @@ class WorkspaceWebappLogoApi(Resource):
         webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None

         if not webapp_logo_file_id:
-            raise NotFound(f'webapp logo is not found')
+            raise NotFound('webapp logo is not found')

         try:
             generator, mimetype = FileService.get_public_image_preview(

@@ -32,7 +32,7 @@ class ToolFilePreviewApi(Resource):
             )

             if not result:
-                raise NotFound(f'file is not found')
+                raise NotFound('file is not found')

             generator, mimetype = result
         except Exception:

@@ -1,27 +0,0 @@
-from extensions.ext_database import db
-from models.model import EndUser
-
-
-def create_or_update_end_user_for_user_id(app_model, user_id):
-    """
-    Create or update session terminal based on user ID.
-    """
-    end_user = db.session.query(EndUser) \
-        .filter(
-            EndUser.tenant_id == app_model.tenant_id,
-            EndUser.session_id == user_id,
-            EndUser.type == 'service_api'
-        ).first()
-
-    if end_user is None:
-        end_user = EndUser(
-            tenant_id=app_model.tenant_id,
-            app_id=app_model.id,
-            type='service_api',
-            is_anonymous=True,
-            session_id=user_id
-        )
-        db.session.add(end_user)
-        db.session.commit()
-
-    return end_user
@@ -1,17 +1,16 @@
-# -*- coding:utf-8 -*-
 import json

 from flask import current_app
-from flask_restful import fields, marshal_with
+from flask_restful import fields, marshal_with, Resource

 from controllers.service_api import api
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import validate_app_token
 from extensions.ext_database import db
 from models.model import App, AppModelConfig
 from models.tools import ApiToolProvider


-class AppParameterApi(AppApiResource):
+class AppParameterApi(Resource):
     """Resource for app variables."""

     variable_fields = {
@@ -43,8 +42,9 @@ class AppParameterApi(AppApiResource):
         'system_parameters': fields.Nested(system_parameters_fields)
     }

+    @validate_app_token
     @marshal_with(parameters_fields)
-    def get(self, app_model: App, end_user):
+    def get(self, app_model: App):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config

@@ -65,8 +65,9 @@ class AppParameterApi(AppApiResource):
             }
         }

-class AppMetaApi(AppApiResource):
-    def get(self, app_model: App, end_user):
+class AppMetaApi(Resource):
+    @validate_app_token
+    def get(self, app_model: App):
         """Get app meta"""
         app_model_config: AppModelConfig = app_model.app_model_config

@@ -78,7 +79,7 @@ class AppMetaApi(AppApiResource):
         # get all tools
         tools = agent_config.get('tools', [])
         url_prefix = (current_app.config.get("CONSOLE_API_URL")
-                      + f"/console/api/workspaces/current/tool-provider/builtin/")
+                      + "/console/api/workspaces/current/tool-provider/builtin/")
         for tool in tools:
             keys = list(tool.keys())
             if len(keys) >= 4:

@@ -1,7 +1,7 @@
 import logging

 from flask import request
-from flask_restful import reqparse
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import InternalServerError

 import services
@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
     ProviderQuotaExceededError,
     UnsupportedAudioTypeError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
-from models.model import App, AppModelConfig
+from models.model import App, AppModelConfig, EndUser
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
@@ -30,8 +30,9 @@ from services.errors.audio import (
 )


-class AudioApi(AppApiResource):
-    def post(self, app_model: App, end_user):
+class AudioApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
+    def post(self, app_model: App, end_user: EndUser):
         app_model_config: AppModelConfig = app_model.app_model_config

         if not app_model_config.speech_to_text_dict['enabled']:
@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
         raise InternalServerError()


-class TextApi(AppApiResource):
-    def post(self, app_model: App, end_user):
+class TextApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         parser = reqparse.RequestParser()
         parser.add_argument('text', type=str, required=True, nullable=False, location='json')
         parser.add_argument('user', type=str, required=True, nullable=False, location='json')
         parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
         args = parser.parse_args()

@@ -85,7 +86,8 @@ class TextApi(AppApiResource):
         response = AudioService.transcript_tts(
             tenant_id=app_model.tenant_id,
             text=args['text'],
-            end_user=args['user'],
+            end_user=end_user,
+            voice=args['voice'] if args['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
             streaming=args['streaming']
         )
@@ -1,14 +1,14 @@
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union

 from flask import Response, stream_with_context
-from flask_restful import reqparse
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import InternalServerError, NotFound

 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import (
     AppUnavailableError,
     CompletionRequestError,
@@ -18,17 +18,19 @@ from controllers.service_api.app.error import (
     ProviderNotInitializeError,
     ProviderQuotaExceededError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.application_queue_manager import ApplicationQueueManager
 from core.entities.application_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
+from models.model import App, EndUser
 from services.completion_service import CompletionService


-class CompletionApi(AppApiResource):
-    def post(self, app_model, end_user):
+class CompletionApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'completion':
             raise AppUnavailableError()

@@ -37,16 +39,12 @@ class CompletionApi(AppApiResource):
         parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('user', required=True, nullable=False, type=str, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')

         args = parser.parse_args()

         streaming = args['response_mode'] == 'streaming'

-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         args['auto_generate_name'] = False

         try:
@@ -81,29 +79,20 @@ class CompletionApi(AppApiResource):
         raise InternalServerError()


-class CompletionStopApi(AppApiResource):
-    def post(self, app_model, end_user, task_id):
+class CompletionStopApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser, task_id):
         if app_model.mode != 'completion':
             raise AppUnavailableError()

-        if end_user is None:
-            parser = reqparse.RequestParser()
-            parser.add_argument('user', required=True, nullable=False, type=str, location='json')
-            args = parser.parse_args()
-
-            user = args.get('user')
-            if user is not None:
-                end_user = create_or_update_end_user_for_user_id(app_model, user)
-            else:
-                raise ValueError("arg user muse be input.")
-
         ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

         return {'result': 'success'}, 200


-class ChatApi(AppApiResource):
-    def post(self, app_model, end_user):
+class ChatApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()

@@ -113,7 +102,6 @@ class ChatApi(AppApiResource):
         parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
-        parser.add_argument('user', type=str, required=True, nullable=False, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')

@@ -121,9 +109,6 @@ class ChatApi(AppApiResource):

         streaming = args['response_mode'] == 'streaming'

-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             response = CompletionService.completion(
                 app_model=app_model,
@@ -156,22 +141,12 @@ class ChatApi(AppApiResource):
         raise InternalServerError()


-class ChatStopApi(AppApiResource):
-    def post(self, app_model, end_user, task_id):
+class ChatStopApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser, task_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()

-        if end_user is None:
-            parser = reqparse.RequestParser()
-            parser.add_argument('user', required=True, nullable=False, type=str, location='json')
-            args = parser.parse_args()
-
-            user = args.get('user')
-            if user is not None:
-                end_user = create_or_update_end_user_for_user_id(app_model, user)
-            else:
-                raise ValueError("arg user muse be input.")
-
         ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

         return {'result': 'success'}, 200
@@ -182,8 +157,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response

         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')
@@ -1,53 +1,44 @@
-# -*- coding:utf-8 -*-
 from flask import request
-from flask_restful import marshal_with, reqparse
+from flask_restful import Resource, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound

 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import NotChatAppError
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import uuid_value
+from models.model import App, EndUser
 from services.conversation_service import ConversationService


-class ConversationApi(AppApiResource):
+class ConversationApi(Resource):

+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
     @marshal_with(conversation_infinite_scroll_pagination_fields)
-    def get(self, app_model, end_user):
+    def get(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()

         parser = reqparse.RequestParser()
         parser.add_argument('last_id', type=uuid_value, location='args')
         parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('user', type=str, location='args')
         args = parser.parse_args()

-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
         except services.errors.conversation.LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")

-class ConversationDetailApi(AppApiResource):
+class ConversationDetailApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     @marshal_with(simple_conversation_fields)
-    def delete(self, app_model, end_user, c_id):
+    def delete(self, app_model: App, end_user: EndUser, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()

         conversation_id = str(c_id)

-        user = request.get_json().get('user')
-
-        if end_user is None and user is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, user)
-
         try:
             ConversationService.delete(app_model, conversation_id, end_user)
         except services.errors.conversation.ConversationNotExistsError:
@@ -55,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
         return {"result": "success"}, 204


-class ConversationRenameApi(AppApiResource):
+class ConversationRenameApi(Resource):

+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     @marshal_with(simple_conversation_fields)
-    def post(self, app_model, end_user, c_id):
+    def post(self, app_model: App, end_user: EndUser, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()

@@ -66,13 +58,9 @@ class ConversationRenameApi(AppApiResource):

         parser = reqparse.RequestParser()
         parser.add_argument('name', type=str, required=False, location='json')
-        parser.add_argument('user', type=str, location='json')
         parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
         args = parser.parse_args()

-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return ConversationService.rename(
                 app_model,

@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from libs.exception import BaseHTTPException
@@ -1,30 +1,27 @@
 from flask import request
-from flask_restful import marshal_with
+from flask_restful import Resource, marshal_with

 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import (
     FileTooLargeError,
     NoFileUploadedError,
     TooManyFilesError,
     UnsupportedFileTypeError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from fields.file_fields import file_fields
+from models.model import App, EndUser
 from services.file_service import FileService


-class FileApi(AppApiResource):
+class FileApi(Resource):

+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
     @marshal_with(file_fields)
-    def post(self, app_model, end_user):
+    def post(self, app_model: App, end_user: EndUser):

         file = request.files['file']
-        user_args = request.form.get('user')
-
-        if end_user is None and user_args is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, user_args)

         # check file
         if 'file' not in request.files:
@@ -1,21 +1,18 @@
-# -*- coding:utf-8 -*-
-from flask_restful import fields, marshal_with, reqparse
+from flask_restful import Resource, fields, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound

 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import NotChatAppError
-from controllers.service_api.wraps import AppApiResource
-from extensions.ext_database import db
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from fields.conversation_fields import message_file_fields
 from libs.helper import TimestampField, uuid_value
-from models.model import EndUser, Message
+from models.model import App, EndUser
 from services.message_service import MessageService


-class MessageListApi(AppApiResource):
+class MessageListApi(Resource):
     feedback_fields = {
         'rating': fields.String
     }
@@ -71,8 +68,9 @@ class MessageListApi(AppApiResource):
         'data': fields.List(fields.Nested(message_fields))
     }

+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
     @marshal_with(message_infinite_scroll_pagination_fields)
-    def get(self, app_model, end_user):
+    def get(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()

@@ -80,12 +78,8 @@ class MessageListApi(AppApiResource):
         parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
         parser.add_argument('first_id', type=uuid_value, location='args')
         parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('user', type=str, location='args')
         args = parser.parse_args()

-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return MessageService.pagination_by_first_id(app_model, end_user,
                                                          args['conversation_id'], args['first_id'], args['limit'])
@@ -95,18 +89,15 @@ class MessageListApi(AppApiResource):
             raise NotFound("First Message Not Exists.")


-class MessageFeedbackApi(AppApiResource):
-    def post(self, app_model, end_user, message_id):
+class MessageFeedbackApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
+    def post(self, app_model: App, end_user: EndUser, message_id):
         message_id = str(message_id)

         parser = reqparse.RequestParser()
         parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
-        parser.add_argument('user', type=str, location='json')
         args = parser.parse_args()

-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
         except services.errors.message.MessageNotExistsError:
@@ -115,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
         return {'result': 'success'}


-class MessageSuggestedApi(AppApiResource):
-    def get(self, app_model, end_user, message_id):
+class MessageSuggestedApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
+    def get(self, app_model: App, end_user: EndUser, message_id):
         message_id = str(message_id)
         if app_model.mode != 'chat':
             raise NotChatAppError()
-        try:
-            message = db.session.query(Message).filter(
-                Message.id == message_id,
-                Message.app_id == app_model.id,
-            ).first()
-
-            if end_user is None and message.from_end_user_id is not None:
-                user = db.session.query(EndUser) \
-                    .filter(
-                        EndUser.tenant_id == app_model.tenant_id,
-                        EndUser.id == message.from_end_user_id,
-                        EndUser.type == 'service_api'
-                    ).first()
-            else:
-                user = end_user
+        try:
             questions = MessageService.get_suggested_questions_after_answer(
                 app_model=app_model,
-                user=user,
+                user=end_user,
                 message_id=message_id,
                 check_enabled=False
             )
@@ -1,7 +1,6 @@
 import json

 from flask import request
-from flask_login import current_user
 from flask_restful import marshal, reqparse
 from sqlalchemy import desc
 from werkzeug.exceptions import NotFound
@@ -29,6 +28,7 @@ class DocumentAddByTextApi(DatasetApiResource):
     """Resource for documents."""

     @cloud_edition_billing_resource_check('vector_space', 'dataset')
+    @cloud_edition_billing_resource_check('documents', 'dataset')
     def post(self, tenant_id, dataset_id):
         """Create document by text."""
         parser = reqparse.RequestParser()
@@ -154,6 +154,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
 class DocumentAddByFileApi(DatasetApiResource):
     """Resource for documents."""
     @cloud_edition_billing_resource_check('vector_space', 'dataset')
+    @cloud_edition_billing_resource_check('documents', 'dataset')
     def post(self, tenant_id, dataset_id):
         """Create document by upload file."""
         args = {}

@@ -46,8 +46,8 @@ class SegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         # validate args
@@ -90,8 +90,8 @@ class SegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)

@@ -182,8 +182,8 @@ class DatasetSegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         # check segment
@@ -200,8 +200,8 @@ class DatasetSegmentApi(DatasetApiResource):
         parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
         args = parser.parse_args()

-        SegmentService.segment_create_args_validate(args['segments'], document)
-        segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
+        SegmentService.segment_create_args_validate(args, document)
+        segment = SegmentService.update_segment(args, segment, document, dataset)
         return {
             'data': marshal(segment, segment_fields),
             'doc_form': document.doc_form
@@ -1,23 +1,40 @@
# -*- coding:utf-8 -*-
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from functools import wraps
from typing import Optional

from flask import current_app, request
from flask_login import user_logged_in
from flask_restful import Resource
from pydantic import BaseModel
from werkzeug.exceptions import NotFound, Unauthorized

from extensions.ext_database import db
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin
from models.model import ApiToken, App
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService


def validate_app_token(view=None):
    def decorator(view):
        @wraps(view)
        def decorated(*args, **kwargs):
class WhereisUserArg(Enum):
    """
    Enum for whereis_user_arg.
    """
    QUERY = 'query'
    JSON = 'json'
    FORM = 'form'


class FetchUserArg(BaseModel):
    fetch_from: WhereisUserArg
    required: bool = False


def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
            api_token = validate_and_get_api_token('app')

            app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
@@ -30,16 +47,35 @@ def validate_app_token(view=None):
            if not app_model.enable_api:
                raise NotFound()

            return view(app_model, None, *args, **kwargs)
        return decorated
            kwargs['app_model'] = app_model

    if view:
            if fetch_user_arg:
                if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
                    user_id = request.args.get('user')
                elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
                    user_id = request.get_json().get('user')
                elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
                    user_id = request.form.get('user')
                else:
                    # use default-user
                    user_id = None

                if not user_id and fetch_user_arg.required:
                    raise ValueError("Arg user must be provided.")

                if user_id:
                    user_id = str(user_id)

                kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)

            return view_func(*args, **kwargs)
        return decorated_view

    if view is None:
        return decorator
    else:
        return decorator(view)

    # if view is None, it means that the decorator is used without parentheses
    # use the decorator as a function for method_decorators
    return decorator


def cloud_edition_billing_resource_check(resource: str,
                                         api_token_type: str,
@@ -53,6 +89,7 @@ def cloud_edition_billing_resource_check(resource: str,
        members = features.members
        apps = features.apps
        vector_space = features.vector_space
        documents_upload_quota = features.documents_upload_quota

        if resource == 'members' and 0 < members.limit <= members.size:
            raise Unauthorized(error_msg)
@@ -60,6 +97,8 @@
            raise Unauthorized(error_msg)
        elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
            raise Unauthorized(error_msg)
        elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
            raise Unauthorized(error_msg)
        else:
            return view(*args, **kwargs)

@@ -129,8 +168,33 @@ def validate_and_get_api_token(scope=None):
    return api_token


class AppApiResource(Resource):
    method_decorators = [validate_app_token]
def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
    """
    Create or update session terminal based on user ID.
    """
    if not user_id:
        user_id = 'DEFAULT-USER'

    end_user = db.session.query(EndUser) \
        .filter(
            EndUser.tenant_id == app_model.tenant_id,
            EndUser.app_id == app_model.id,
            EndUser.session_id == user_id,
            EndUser.type == 'service_api'
        ).first()

    if end_user is None:
        end_user = EndUser(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            type='service_api',
            is_anonymous=True if user_id == 'DEFAULT-USER' else False,
            session_id=user_id
        )
        db.session.add(end_user)
        db.session.commit()

    return end_user


class DatasetApiResource(Resource):
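The reworked decorator above can be applied either bare (via method_decorators, as before) or parameterized, in which case it injects `app_model` and, when requested, `end_user` keyword arguments into the view. A minimal usage sketch under that reading — the resource class below is hypothetical, only the decorator, FetchUserArg, and WhereisUserArg come from the diff:

class CompletionApi(Resource):
    # Parameterized form: 'user' is read from the JSON body and resolved to an
    # EndUser row by create_or_update_end_user_for_user_id before the view runs.
    method_decorators = [
        validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    ]

    def post(self, app_model: App, end_user: EndUser):
        ...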
@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import json

from flask import current_app
@@ -77,7 +76,7 @@ class AppMeta(WebApiResource):
        # get all tools
        tools = agent_config.get('tools', [])
        url_prefix = (current_app.config.get("CONSOLE_API_URL")
                      + f"/console/api/workspaces/current/tool-provider/builtin/")
                      + "/console/api/workspaces/current/tool-provider/builtin/")
        for tool in tools:
            keys = list(tool.keys())
            if len(keys) >= 4:
@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import logging

from flask import request
@@ -69,17 +68,23 @@ class AudioApi(WebApiResource):
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            logging.exception(f"internal server error: {str(e)}")
            raise InternalServerError()


class TextApi(WebApiResource):
    def post(self, app_model: App, end_user):
        app_model_config: AppModelConfig = app_model.app_model_config

        if not app_model_config.text_to_speech_dict['enabled']:
            raise AppUnavailableError()

        try:
            response = AudioService.transcript_tts(
                tenant_id=app_model.tenant_id,
                text=request.form['text'],
                end_user=end_user.external_user_id,
                voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
                streaming=False
            )

@@ -106,7 +111,7 @@ class TextApi(WebApiResource):
        except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            logging.exception(f"internal server error: {str(e)}")
            raise InternalServerError()
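Worth noting about the logging change in both hunks above: logging.exception always appends the active traceback, so the new f-string only adds the exception text to the message line itself. A standalone illustration:

import logging

try:
    raise RuntimeError("boom")
except Exception as e:
    # The message line now carries the error text; the traceback is appended either way.
    logging.exception(f"internal server error: {str(e)}")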
@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union

from flask import Response, stream_with_context
from flask_restful import reqparse
@@ -154,8 +154,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
        return Response(response=json.dumps(response), status=200, mimetype='application/json')
    else:
        def generate() -> Generator:
            for chunk in response:
                yield chunk
            yield from response

        return Response(stream_with_context(generate()), status=200,
                        mimetype='text/event-stream')
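The compact_response change above replaces a manual re-yield loop with generator delegation; for this streaming use the two forms are behaviorally equivalent, as the sketch below shows (function names illustrative):

def generate_explicit(response):
    for chunk in response:   # before: re-yield each chunk
        yield chunk

def generate_delegating(response):
    yield from response      # after: delegate the whole iteration

assert list(generate_explicit(iter('abc'))) == list(generate_delegating(iter('abc')))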
@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from libs.exception import BaseHTTPException
@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union

from flask import Response, stream_with_context
from flask_restful import fields, marshal_with, reqparse
@@ -160,8 +160,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
        return Response(response=json.dumps(response), status=200, mimetype='application/json')
    else:
        def generate() -> Generator:
            for chunk in response:
                yield chunk
            yield from response

        return Response(stream_with_context(generate()), status=200,
                        mimetype='text/event-stream')
@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import uuid

from flask import request

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-

from flask import current_app
from flask_restful import fields, marshal_with

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from functools import wraps

from flask import request
@@ -1,49 +0,0 @@
from typing import List, cast

from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel


class CalcTokenMixin:

    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
        """
        Got the rest tokens available for the model after excluding messages tokens and completion max tokens

        :param model_config:
        :param messages:
        :return:
        """
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return 0

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            messages
        )

        rest_tokens = model_context_tokens - max_tokens - prompt_tokens

        return rest_tokens


class ExceededLLMTokensLimitError(Exception):
    pass
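The budget rule the deleted mixin implemented reduces to a single subtraction; a worked example with illustrative numbers:

# rest_tokens = context window - reserved completion tokens - prompt tokens
model_context_tokens = 4096   # ModelPropertyKey.CONTEXT_SIZE
max_tokens = 512              # completion budget from the parameter rules
prompt_tokens = 3700          # measured by get_num_tokens
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
print(rest_tokens)            # -116: negative, so callers must summarize history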
@@ -1,360 +0,0 @@
from typing import Any, List, Optional, Sequence, Tuple, Union

from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    get_buffer_string,
)
from langchain.tools import BaseTool
from pydantic import root_validator

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM


class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_config: ModelConfigEntity = None
    model_config: ModelConfigEntity
    agent_llm_callback: Optional[AgentLLMCallback] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        return values

    @classmethod
    def from_llm_and_tools(
            cls,
            model_config: ModelConfigEntity,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            agent_llm_callback: Optional[AgentLLMCallback] = None,
            **kwargs: Any,
    ) -> BaseSingleActionAgent:
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_config=model_config,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            agent_llm_callback=agent_llm_callback,
            **kwargs,
        )

    def should_use_agent(self, query: str):
        """
        return should use agent

        :param query:
        :return:
        """
        original_max_tokens = 0
        for parameter_rule in self.model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
                                       or self.model_config.parameters.get(parameter_rule.use_template)) or 0

        self.model_config.parameters['max_tokens'] = 40

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        try:
            prompt_messages = lc_messages_to_prompt_messages(messages)
            model_instance = ModelInstance(
                provider_model_bundle=self.model_config.provider_model_bundle,
                model=self.model_config.model,
            )

            tools = []
            for function in self.functions:
                tool = PromptMessageTool(
                    **function
                )

                tools.append(tool)

            result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                tools=tools,
                stream=False,
                model_parameters={
                    'temperature': 0.2,
                    'top_p': 0.3,
                    'max_tokens': 1500
                }
            )
        except Exception as e:
            raise e

        self.model_config.parameters['max_tokens'] = original_max_tokens

        return True if result.message.tool_calls else False

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        prompt_messages = lc_messages_to_prompt_messages(messages)

        # summarize messages if rest_tokens < 0
        try:
            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        model_instance = ModelInstance(
            provider_model_bundle=self.model_config.provider_model_bundle,
            model=self.model_config.model,
        )

        tools = []
        for function in self.functions:
            tool = PromptMessageTool(
                **function
            )

            tools.append(tool)

        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            tools=tools,
            stream=False,
            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
            model_parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )

        ai_message = AIMessage(
            content=result.message.content or "",
            additional_kwargs={
                'function_call': {
                    'id': result.message.tool_calls[0].id,
                    **result.message.tool_calls[0].function.dict()
                } if result.message.tool_calls else None
            }
        )
        agent_decision = _parse_ai_message(ai_message)

        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
            tool_inputs = agent_decision.tool_input
            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                tool_inputs['query'] = kwargs['input']
                agent_decision.tool_input = tool_inputs

        return agent_decision

    @classmethod
    def get_system_message(cls):
        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "The current date or current time you know is wrong.\n"
                                     "Respond directly if appropriate.")

    def return_stopped_response(
            self,
            early_stopping_method: str,
            intermediate_steps: List[Tuple[AgentAction, str]],
            **kwargs: Any,
    ) -> AgentFinish:
        try:
            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
        except ValueError:
            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")

    def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
        rest_tokens = self.get_message_rest_tokens(
            self.model_config,
            messages,
            **kwargs
        )

        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages

        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages

    def predict_new_summary(
            self, messages: List[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        if model_config.provider == 'azure_openai':
            model = model_config.model
            model = model.replace("gpt-35", "gpt-3.5")
        else:
            model = model_config.credentials.get("base_model_name")

        tiktoken_ = _import_tiktoken()
        try:
            encoding = tiktoken_.encoding_for_model(model)
        except KeyError:
            model = "cl100k_base"
            encoding = tiktoken_.get_encoding(model)

        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}."
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if kwargs.get('functions'):
            for function in kwargs.get('functions'):
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                num_tokens += len(encoding.encode(field_key))
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens
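The heart of the token accounting in the deleted get_num_tokens_from_messages is the per-message tiktoken loop; a self-contained sketch of just that core (requires the tiktoken package; the constants follow the OpenAI cookbook values used above):

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
tokens_per_message = 3  # gpt-4 style framing per the cookbook
messages = [{"role": "user", "content": "hello"}]

num_tokens = 0
for message in messages:
    num_tokens += tokens_per_message
    for value in message.values():
        num_tokens += len(encoding.encode(value))
num_tokens += 3  # every reply is primed with <im_start>assistant
print(num_tokens)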
@@ -1,305 +0,0 @@
import re
from typing import Any, List, Optional, Sequence, Tuple, Union, cast

from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    AIMessage,
    BaseMessage,
    HumanMessage,
    OutputParserException,
    get_buffer_string,
)
from langchain.tools import BaseTool

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $JSON_BLOB, as shown:

```
{{{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}}}
```

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}}}
```"""


class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_config: ModelConfigEntity = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def should_use_agent(self, query: str):
        """
        return should use agent
        Using the ReACT mode to determine whether an agent is needed is costly,
        so it's better to just use an Agent for reasoning, which is cheaper.

        :param query:
        :return:
        """
        return True

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observatons
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
        prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])

        messages = []
        if prompts:
            messages = prompts[0].to_messages()

        prompt_messages = lc_messages_to_prompt_messages(messages)

        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
        if rest_tokens < 0:
            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

        try:
            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        except Exception as e:
            raise e

        try:
            agent_decision = self.output_parser.parse(full_output)
            if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
                tool_inputs = agent_decision.tool_input
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                    tool_inputs['query'] = kwargs['input']
                    agent_decision.tool_input = tool_inputs
            return agent_decision
        except OutputParserException:
            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                          "I don't know how to respond to that."}, "")

    def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
        if len(intermediate_steps) >= 2 and self.summary_model_config:
            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
            should_summary_messages = [AIMessage(content=observation)
                                       for _, observation in should_summary_intermediate_steps]
            if self.moving_summary_index == 0:
                should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))

            self.moving_summary_index = len(intermediate_steps)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        if self.moving_summary_buffer and 'chat_history' in kwargs:
            kwargs["chat_history"].pop()

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        if 'chat_history' in kwargs:
            kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))

        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)

    def predict_new_summary(
            self, messages: List[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    @classmethod
    def create_prompt(
            cls,
            tools: Sequence[BaseTool],
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[List[str]] = None,
            memory_prompts: Optional[List[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        tool_strings = []
        for tool in tools:
            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
        formatted_tools = "\n".join(tool_strings)
        tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
            HumanMessagePromptTemplate.from_template(human_message_template),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def create_completion_prompt(
            cls,
            tools: Sequence[BaseTool],
            prefix: str = PREFIX,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Create prompt in the style of the zero shot agent.

        Args:
            tools: List of tools the agent will have access to, used to format the
                prompt.
            prefix: String to put before the list of tools.
            input_variables: List of input variables the final prompt will expect.

        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""

        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)

    def _construct_scratchpad(
            self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
        agent_scratchpad = ""
        for action, observation in intermediate_steps:
            agent_scratchpad += action.log
            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"

        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if agent_scratchpad:
            llm_chain = cast(LLMChain, self.llm_chain)
            if llm_chain.model_config.mode == "chat":
                return (
                    f"This was your previous work "
                    f"(but I haven't seen any of it! I only see what "
                    f"you return as final answer):\n{agent_scratchpad}"
                )
            else:
                return agent_scratchpad
        else:
            return agent_scratchpad

    @classmethod
    def from_llm_and_tools(
            cls,
            model_config: ModelConfigEntity,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            output_parser: Optional[AgentOutputParser] = None,
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[List[str]] = None,
            memory_prompts: Optional[List[BasePromptTemplate]] = None,
            agent_llm_callback: Optional[AgentLLMCallback] = None,
            **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        if model_config.mode == "chat":
            prompt = cls.create_prompt(
                tools,
                prefix=prefix,
                suffix=suffix,
                human_message_template=human_message_template,
                format_instructions=format_instructions,
                input_variables=input_variables,
                memory_prompts=memory_prompts,
            )
        else:
            prompt = cls.create_completion_prompt(
                tools,
                prefix=prefix,
                format_instructions=format_instructions,
                input_variables=input_variables,
            )
        llm_chain = LLMChain(
            model_config=model_config,
            prompt=prompt,
            callback_manager=callback_manager,
            agent_llm_callback=agent_llm_callback,
            parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )
@@ -1,5 +1,6 @@
import time
from typing import Generator, List, Optional, Tuple, Union, cast
from collections.abc import Generator
from typing import Optional, Union, cast

from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import (
@@ -83,8 +84,8 @@ class AppRunner:

        return rest_tokens

    def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
                              prompt_messages: List[PromptMessage]):
    def recalc_llm_max_tokens(self, model_config: ModelConfigEntity,
                              prompt_messages: list[PromptMessage]):
        # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -126,7 +127,7 @@ class AppRunner:
                                 query: Optional[str] = None,
                                 context: Optional[str] = None,
                                 memory: Optional[TokenBufferMemory] = None) \
            -> Tuple[List[PromptMessage], Optional[List[str]]]:
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Organize prompt messages
        :param context:
@@ -295,7 +296,7 @@ class AppRunner:
                                          tenant_id: str,
                                          app_orchestration_config_entity: AppOrchestrationConfigEntity,
                                          inputs: dict,
                                          query: str) -> Tuple[bool, dict, str]:
                                          query: str) -> tuple[bool, dict, str]:
        """
        Process sensitive_word_avoidance.
        :param app_id: app id
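The typing hunks in this file, and in the ones that follow, all apply the same migration: typing.Generator gives way to collections.abc.Generator, and List/Tuple/Dict give way to the PEP 585 builtin generics. In miniature (valid on Python 3.9+; function names here are illustrative):

from collections.abc import Generator  # replaces typing.Generator

def stream_chunks() -> Generator[str, None, None]:
    yield "chunk"

def recalc(prompt_messages: list[str]) -> tuple[int, dict[str, int]]:
    # builtin generics replace typing.List / typing.Tuple / typing.Dict
    return len(prompt_messages), {"max_tokens": 1500}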
@@ -1,4 +1,3 @@
import json
import logging
from typing import cast

@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)
@@ -38,7 +37,7 @@ class AssistantApplicationRunner(AppRunner):
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError(f"App not found")
            raise ValueError("App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):

            # convert db variables to tool variables
            tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)

            message_chain = self._init_message_chain(
                message=message,
                query=query
            )

            # init model instance
            model_instance = ModelInstance(
@@ -201,6 +195,10 @@ class AssistantApplicationRunner(AppRunner):
            if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
                agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

            db.session.refresh(conversation)
            db.session.refresh(message)
            db.session.close()

            # start agent runner
            if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
                assistant_cot_runner = AssistantCotApplicationRunner(
@@ -290,38 +288,6 @@ class AssistantApplicationRunner(AppRunner):
                'pool': db_variables.variables
            })

    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
        """
        Init MessageChain
        :param message: message
        :param query: query
        :return:
        """
        message_chain = MessageChain(
            message_id=message.id,
            type="AgentExecutor",
            input=json.dumps({
                "input": query
            })
        )

        db.session.add(message_chain)
        db.session.commit()

        return message_chain

    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
        """
        Save MessageChain
        :param message_chain: message chain
        :param output_text: output text
        :return:
        """
        message_chain.output = json.dumps({
            "output": output_text
        })
        db.session.commit()

    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                         message: Message) -> LLMUsage:
        """
@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
@@ -35,7 +35,7 @@ class BasicApplicationRunner(AppRunner):
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError(f"App not found")
            raise ValueError("App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

@@ -181,7 +181,7 @@ class BasicApplicationRunner(AppRunner):
            return

        # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
        self.recale_llm_max_tokens(
        self.recalc_llm_max_tokens(
            model_config=app_orchestration_config.model_config,
            prompt_messages=prompt_messages
        )
@@ -192,6 +192,8 @@ class BasicApplicationRunner(AppRunner):
            model=app_orchestration_config.model_config.model
        )

        db.session.close()

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=app_orchestration_config.model_config.parameters,
@@ -1,7 +1,8 @@
import json
import logging
import time
from typing import Generator, Optional, Union, cast
from collections.abc import Generator
from typing import Optional, Union, cast

from pydantic import BaseModel

@@ -88,6 +89,10 @@ class GenerateTaskPipeline:
        Process generate task pipeline.
        :return:
        """
        db.session.refresh(self._conversation)
        db.session.refresh(self._message)
        db.session.close()

        if stream:
            return self._process_stream_response()
        else:
@@ -118,7 +123,7 @@ class GenerateTaskPipeline:
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
            elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
@@ -174,7 +179,7 @@ class GenerateTaskPipeline:
                'id': self._message.id,
                'message_id': self._message.id,
                'mode': self._conversation.mode,
                'answer': event.llm_result.message.content,
                'answer': self._task_state.llm_result.message.content,
                'metadata': {},
                'created_at': int(self._message.created_at.timestamp())
            }
@@ -201,7 +206,7 @@ class GenerateTaskPipeline:
                data = self._error_to_stream_response_data(self._handle_error(event))
                yield self._yield_response(data)
                break
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
            elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
@@ -302,6 +307,7 @@ class GenerateTaskPipeline:
                    .first()
                )
                db.session.refresh(agent_thought)
                db.session.close()

                if agent_thought:
                    response = {
@@ -329,6 +335,8 @@ class GenerateTaskPipeline:
                    .filter(MessageFile.id == event.message_file_id)
                    .first()
                )
                db.session.close()

                # get extension
                if '.' in message_file.url:
                    extension = f'.{message_file.url.split(".")[-1]}'
@@ -353,7 +361,7 @@ class GenerateTaskPipeline:

                yield self._yield_response(response)

            elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
            elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
@@ -412,6 +420,7 @@ class GenerateTaskPipeline:
        usage = llm_result.usage

        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
        self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()

        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
        self._message.message_tokens = usage.prompt_tokens
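The isinstance rewrites in this pipeline rely on PEP 604: since Python 3.10 a union type like QueueStopEvent | QueueMessageEndEvent is accepted wherever a tuple of classes was. A standalone check with stand-in classes (the real event classes live in the queue entities module):

class QueueStopEvent: ...
class QueueMessageEndEvent: ...

event = QueueMessageEndEvent()
assert isinstance(event, QueueStopEvent | QueueMessageEndEvent)   # PEP 604 form
assert isinstance(event, (QueueStopEvent, QueueMessageEndEvent))  # equivalent tuple form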
@@ -1,7 +1,7 @@
import logging
import threading
import time
from typing import Any, Dict, Optional
from typing import Any, Optional

from flask import Flask, current_app
from pydantic import BaseModel
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)

class ModerationRule(BaseModel):
    type: str
    config: Dict[str, Any]
    config: dict[str, Any]


class OutputModerationHandler(BaseModel):
@@ -2,7 +2,8 @@ import json
import logging
import threading
import uuid
from typing import Any, Generator, Optional, Tuple, Union, cast
from collections.abc import Generator
from typing import Any, Optional, Union, cast

from flask import Flask, current_app
from pydantic import ValidationError
@@ -27,6 +28,7 @@ from core.entities.application_entities import (
    ModelConfigEntity,
    PromptTemplateEntity,
    SensitiveWordAvoidanceEntity,
    TextToSpeechEntity,
)
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
@@ -199,7 +201,7 @@ class ApplicationManager:
                logger.exception("Unknown Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            finally:
                db.session.remove()
                db.session.close()

    def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
                         queue_manager: ApplicationQueueManager,
@@ -231,8 +233,6 @@ class ApplicationManager:
            else:
                logger.exception(e)
                raise e
        finally:
            db.session.remove()

    def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
            -> AppOrchestrationConfigEntity:
@@ -571,7 +571,11 @@ class ApplicationManager:
        text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
        if text_to_speech_dict:
            if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
                properties['text_to_speech'] = True
                properties['text_to_speech'] = TextToSpeechEntity(
                    enabled=text_to_speech_dict.get('enabled'),
                    voice=text_to_speech_dict.get('voice'),
                    language=text_to_speech_dict.get('language'),
                )

        # sensitive word avoidance
        sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
@@ -585,7 +589,7 @@ class ApplicationManager:
        return AppOrchestrationConfigEntity(**properties)

    def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
            -> Tuple[Conversation, Message]:
            -> tuple[Conversation, Message]:
        """
        Initialize generate records
        :param application_generate_entity: application generate entity
@@ -645,6 +649,7 @@ class ApplicationManager:

            db.session.add(conversation)
            db.session.commit()
            db.session.refresh(conversation)
        else:
            conversation = (
                db.session.query(Conversation)
@@ -683,6 +688,7 @@ class ApplicationManager:

        db.session.add(message)
        db.session.commit()
        db.session.refresh(message)

        for file in application_generate_entity.files:
            message_file = MessageFile(
@@ -1,7 +1,8 @@
import queue
import time
from collections.abc import Generator
from enum import Enum
from typing import Any, Generator
from typing import Any

from sqlalchemy.orm import DeclarativeMeta
@@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Optional, Union, cast

from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
@@ -37,7 +37,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
        self._message_agent_thought = None

    @property
    def agent_loops(self) -> List[AgentLoop]:
    def agent_loops(self) -> list[AgentLoop]:
        return self._agent_loops

    def clear_agent_loops(self) -> None:
@@ -95,14 +95,14 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):

    def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List[BaseMessage]],
            serialized: dict[str, Any],
            messages: list[list[BaseMessage]],
            **kwargs: Any
    ) -> Any:
        pass

    def on_llm_start(
            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
            self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
    ) -> None:
        pass

@@ -120,7 +120,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):

    def on_tool_start(
            self,
            serialized: Dict[str, Any],
            serialized: dict[str, Any],
            input_str: str,
            **kwargs: Any,
    ) -> None:
@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
@@ -21,7 +21,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
    def on_tool_start(
            self,
            tool_name: str,
            tool_inputs: Dict[str, Any],
            tool_inputs: dict[str, Any],
    ) -> None:
        """Do nothing."""
        print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)

@@ -29,7 +29,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
    def on_tool_end(
            self,
            tool_name: str,
            tool_inputs: Dict[str, Any],
            tool_inputs: dict[str, Any],
            tool_outputs: str,
    ) -> None:
        """If not the final action, print out observation."""
@@ -1,9 +1,7 @@
from typing import List

from langchain.schema import Document

from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import InvokeFrom
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import DatasetQuery, DocumentSegment
from models.model import DatasetRetrieverResource
@@ -40,22 +38,26 @@ class DatasetIndexToolCallbackHandler:
        db.session.add(dataset_query)
        db.session.commit()

    def on_tool_end(self, documents: List[Document]) -> None:
    def on_tool_end(self, documents: list[Document]) -> None:
        """Handle tool end."""
        for document in documents:
            doc_id = document.metadata['doc_id']
            query = db.session.query(DocumentSegment).filter(
                DocumentSegment.index_node_id == document.metadata['doc_id']
            )

            # if 'dataset_id' in document.metadata:
            if 'dataset_id' in document.metadata:
                query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])

            # add hit count to document segment
            db.session.query(DocumentSegment).filter(
                DocumentSegment.index_node_id == doc_id
            ).update(
            query.update(
                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                synchronize_session=False
            )

            db.session.commit()

    def return_retriever_resource_info(self, resource: List):
    def return_retriever_resource_info(self, resource: list):
        """Handle return_retriever_resource_info."""
        if resource and len(resource) > 0:
            for item in resource:
@@ -1,6 +1,6 @@
import os
import sys
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
@@ -16,8 +16,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):

    def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List[BaseMessage]],
            serialized: dict[str, Any],
            messages: list[list[BaseMessage]],
            **kwargs: Any
    ) -> Any:
        print_text("\n[on_chat_model_start]\n", color='blue')
@@ -26,7 +26,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
            print_text(str(sub_message) + "\n", color='blue')

    def on_llm_start(
            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
            self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
    ) -> None:
        """Print out the prompts."""
        print_text("\n[on_llm_start]\n", color='blue')
@@ -48,13 +48,13 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
        print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
    ) -> None:
        """Print out that we are entering a chain."""
        chain_type = serialized['id'][-1]
        print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
        """Print out that we finished a chain."""
        print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')

@@ -66,7 +66,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):

    def on_tool_start(
            self,
            serialized: Dict[str, Any],
            serialized: dict[str, Any],
            input_str: str,
            **kwargs: Any,
    ) -> None:
@@ -1,107 +0,0 @@
import tempfile
from pathlib import Path
from typing import List, Optional, Union

import requests
from flask import current_app
from langchain.document_loaders import Docx2txtLoader, TextLoader
from langchain.schema import Document

from core.data_loader.loader.csv_loader import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
from extensions.ext_storage import storage
from models.model import UploadFile

SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"


class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)

return cls.load_from_file(file_path, return_text, upload_file, is_automatic)

@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]:
response = requests.get(url, headers={
"User-Agent": USER_AGENT
})

with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(url).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
with open(file_path, 'wb') as file:
file.write(response.content)

return cls.load_from_file(file_path, return_text)

@classmethod
def load_from_file(cls, file_path: str, return_text: bool = False,
upload_file: Optional[UploadFile] = None,
is_automatic: bool = False) -> Union[List[Document], str]:
input_file = Path(file_path)
delimiter = '\n'
file_extension = input_file.suffix.lower()
etl_type = current_app.config['ETL_TYPE']
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
if etl_type == 'Unstructured':
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
else MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension in ['.docx', '.doc']:
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
elif file_extension == '.eml':
loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
elif file_extension == '.ppt':
loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
elif file_extension == '.pptx':
loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
elif file_extension == '.xml':
loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
else:
# txt
loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
else TextLoader(file_path, autodetect_encoding=True)
else:
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension in ['.docx', '.doc']:
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)

return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
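The deleted FileExtractor above dispatches on the file suffix through a long if/elif chain. A hypothetical condensed sketch of the same dispatch pattern using a lookup table; the loader classes below are stand-ins, not the real Dify loaders:

from pathlib import Path
from typing import Callable

class TextLoader:
    def __init__(self, path: str): self.path = path

class PdfLoader:
    def __init__(self, path: str): self.path = path

# Suffix -> loader factory; unknown suffixes fall back to plain text,
# mirroring the `else: # txt` branch in the deleted code.
_LOADERS: dict[str, Callable[[str], object]] = {
    '.pdf': PdfLoader,
    '.txt': TextLoader,
}

def pick_loader(file_path: str) -> object:
    suffix = Path(file_path).suffix.lower()
    return _LOADERS.get(suffix, TextLoader)(file_path)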
@@ -1,55 +0,0 @@
import logging
from typing import List, Optional

from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

from extensions.ext_storage import storage
from models.model import UploadFile

logger = logging.getLogger(__name__)


class PdfLoader(BaseLoader):
"""Load pdf files.


Args:
file_path: Path to the file to load.
"""

def __init__(
self,
file_path: str,
upload_file: Optional[UploadFile] = None
):
"""Initialize with file path."""
self._file_path = file_path
self._upload_file = upload_file

def load(self) -> List[Document]:
plaintext_file_key = ''
plaintext_file_exists = False
if self._upload_file:
if self._upload_file.hash:
plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
+ self._upload_file.hash + '.0625.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
pass
documents = PyPDFium2Loader(file_path=self._file_path).load()
text_list = []
for document in documents:
text_list.append(document.page_content)
text = "\n\n".join(text_list)

# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))

return documents
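PdfLoader.load() above is a cache-aside pattern: read the extracted plaintext from storage if it exists, otherwise run the expensive PDF extraction and then populate the cache. A minimal self-contained sketch of the same flow; InMemoryStorage is a stand-in for extensions.ext_storage, not the real client:

class InMemoryStorage:
    def __init__(self):
        self._data: dict[str, bytes] = {}
    def load(self, key: str) -> bytes:
        if key not in self._data:
            raise FileNotFoundError(key)
        return self._data[key]
    def save(self, key: str, value: bytes) -> None:
        self._data[key] = value

def load_text(storage: InMemoryStorage, key: str, extract) -> str:
    try:
        return storage.load(key).decode('utf-8')   # cache hit
    except FileNotFoundError:
        text = extract()                            # expensive extraction path
        storage.save(key, text.encode('utf-8'))     # populate cache for next time
        return text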
@@ -1,11 +1,12 @@
from typing import Any, Dict, Optional, Sequence, cast
from collections.abc import Sequence
from typing import Any, Optional, cast

from langchain.schema import Document
from sqlalchemy import func

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment

@@ -22,10 +23,10 @@ class DatasetDocumentStore:
self._document_id = document_id

@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore":
def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
return cls(**config_dict)

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Serialize to dict."""
return {
"dataset_id": self._dataset.id,
@@ -40,7 +41,7 @@ class DatasetDocumentStore:
return self._user_id

@property
def docs(self) -> Dict[str, Document]:
def docs(self) -> dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id
).all()
@@ -1,14 +1,14 @@
import base64
import logging
from typing import List, Optional, cast
from typing import Optional, cast

import numpy as np
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError

from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.entity.embedding import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
@@ -21,7 +21,7 @@ class CacheEmbedding(Embeddings):
self._model_instance = model_instance
self._user = user

def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
text_embeddings = []
try:
@@ -52,7 +52,7 @@ class CacheEmbedding(Embeddings):

return text_embeddings

def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
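CacheEmbedding above keys its cache on helper.generate_text_hash(text), so repeated queries skip the embedding model entirely. A minimal sketch of that idea with a plain dict cache and hashlib standing in for the helper:

import hashlib

_cache: dict[str, list[float]] = {}

def embed_query_cached(text: str, embed) -> list[float]:
    # Hash the text to get a stable, compact cache key.
    key = hashlib.sha256(text.encode('utf-8')).hexdigest()
    if key in _cache:
        return _cache[key]          # cache hit: no model call
    vector = embed(text)            # expensive model invocation
    _cache[key] = vector
    return vector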
api/core/entities/agent_entities.py (new file, 8 additions)
@@ -0,0 +1,8 @@
from enum import Enum


class PlanningStrategy(Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
@@ -42,6 +42,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
"""
Advanced Completion Prompt Template Entity.
"""

class RolePrefixEntity(BaseModel):
"""
Role Prefix Entity.
@@ -57,6 +58,7 @@ class PromptTemplateEntity(BaseModel):
"""
Prompt Template Entity.
"""

class PromptType(Enum):
"""
Prompt Type.
@@ -97,6 +99,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
"""
Dataset Retrieve Config Entity.
"""

class RetrieveStrategy(Enum):
"""
Dataset Retrieve Strategy.
@@ -143,6 +146,15 @@ class SensitiveWordAvoidanceEntity(BaseModel):
config: dict[str, Any] = {}


class TextToSpeechEntity(BaseModel):
"""
Text To Speech Entity.
"""
enabled: bool
voice: Optional[str] = None
language: Optional[str] = None


class FileUploadEntity(BaseModel):
"""
File Upload Entity.
@@ -159,6 +171,7 @@ class AgentToolEntity(BaseModel):
tool_name: str
tool_parameters: dict[str, Any] = {}


class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
@@ -166,6 +179,7 @@ class AgentPromptEntity(BaseModel):
first_prompt: str
next_iteration: str


class AgentScratchpadUnit(BaseModel):
"""
Agent Scratchpad Unit.
@@ -182,12 +196,14 @@ class AgentScratchpadUnit(BaseModel):
thought: Optional[str] = None
action_str: Optional[str] = None
observation: Optional[str] = None
action: Optional[Action] = None
action: Optional[Action] = None


class AgentEntity(BaseModel):
"""
Agent Entity.
"""

class Strategy(Enum):
"""
Agent Strategy.
@@ -202,6 +218,7 @@ class AgentEntity(BaseModel):
tools: list[AgentToolEntity] = None
max_iteration: int = 5


class AppOrchestrationConfigEntity(BaseModel):
"""
App Orchestration Config Entity.
@@ -219,7 +236,7 @@ class AppOrchestrationConfigEntity(BaseModel):
show_retrieve_source: bool = False
more_like_this: bool = False
speech_to_text: bool = False
text_to_speech: bool = False
text_to_speech: dict = {}
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None


@@ -41,7 +41,7 @@ class ImagePromptMessageFile(PromptMessageFile):

class LCHumanMessageWithFiles(HumanMessage):
# content: Union[str, List[Union[str, Dict]]]
# content: Union[str, list[Union[str, Dict]]]
content: str
files: list[PromptMessageFile]
@@ -1,8 +1,9 @@
import datetime
import json
import logging
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Optional

from pydantic import BaseModel

@@ -135,7 +136,7 @@ class ProviderConfiguration(BaseModel):
if self.provider.provider_credential_schema else []
)

def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
@@ -282,7 +283,7 @@ class ProviderConfiguration(BaseModel):
return None

def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
-> Tuple[ProviderModel, dict]:
-> tuple[ProviderModel, dict]:
"""
Validate custom model credentials.

@@ -711,7 +712,7 @@ class ProviderConfigurations(BaseModel):
Model class for provider configuration dict.
"""
tenant_id: str
configurations: Dict[str, ProviderConfiguration] = {}
configurations: dict[str, ProviderConfiguration] = {}

def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id)
@@ -759,7 +760,7 @@ class ProviderConfigurations(BaseModel):

return all_models

def to_list(self) -> List[ProviderConfiguration]:
def to_list(self) -> list[ProviderConfiguration]:
"""
Convert to list.
@@ -61,7 +61,7 @@ class Extensible:

builtin_file_path = os.path.join(subdir_path, '__builtin__')
if os.path.exists(builtin_file_path):
with open(builtin_file_path, 'r', encoding='utf-8') as f:
with open(builtin_file_path, encoding='utf-8') as f:
position = int(f.read().strip())

if (extension_name + '.py') not in file_names:
@@ -93,7 +93,7 @@ class Extensible:
json_path = os.path.join(subdir_path, 'schema.json')
json_data = {}
if os.path.exists(json_path):
with open(json_path, 'r', encoding='utf-8') as f:
with open(json_path, encoding='utf-8') as f:
json_data = json.load(f)

extensions[extension_name] = ModuleExtension(
@@ -1,199 +0,0 @@
import logging
from typing import Optional, cast

from langchain.tools import BaseTool

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (
AgentEntity,
AppOrchestrationConfigEntity,
InvokeFrom,
ModelConfigEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message

logger = logging.getLogger(__name__)


class AgentRunnerFeature:
def __init__(self, tenant_id: str,
app_orchestration_config: AppOrchestrationConfigEntity,
model_config: ModelConfigEntity,
config: AgentEntity,
queue_manager: ApplicationQueueManager,
message: Message,
user_id: str,
agent_llm_callback: AgentLLMCallback,
callback: AgentLoopGatherCallbackHandler,
memory: Optional[TokenBufferMemory] = None,) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param app_orchestration_config: app orchestration config
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
"""
self.tenant_id = tenant_id
self.app_orchestration_config = app_orchestration_config
self.model_config = model_config
self.config = config
self.queue_manager = queue_manager
self.message = message
self.user_id = user_id
self.agent_llm_callback = agent_llm_callback
self.callback = callback
self.memory = memory

def run(self, query: str,
invoke_from: InvokeFrom) -> Optional[str]:
"""
Retrieve agent loop result.
:param query: query
:param invoke_from: invoke from
:return:
"""
provider = self.config.provider
model = self.config.model
tool_configs = self.config.tools

# check model is support tool calling
provider_instance = model_provider_factory.get_provider_instance(provider=provider)
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
model_type_instance = cast(LargeLanguageModel, model_type_instance)

# get model schema
model_schema = model_type_instance.get_model_schema(
model=model,
credentials=self.model_config.credentials
)

if not model_schema:
return None

planning_strategy = PlanningStrategy.REACT
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.FUNCTION_CALL

tools = self.to_tools(
tool_configs=tool_configs,
invoke_from=invoke_from,
callbacks=[self.callback, DifyStdOutCallbackHandler()],
)

if len(tools) == 0:
return None

agent_configuration = AgentConfiguration(
strategy=planning_strategy,
model_config=self.model_config,
tools=tools,
memory=self.memory,
max_iterations=10,
max_execution_time=400.0,
early_stopping_method="generate",
agent_llm_callback=self.agent_llm_callback,
callbacks=[self.callback, DifyStdOutCallbackHandler()]
)

agent_executor = AgentExecutor(agent_configuration)

try:
# check if should use agent
should_use_agent = agent_executor.should_use_agent(query)
if not should_use_agent:
return None

result = agent_executor.run(query)
return result.output
except Exception as ex:
logger.exception("agent_executor run failed")
return None

def to_dataset_retriever_tool(self, tool_config: dict,
invoke_from: InvokeFrom) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tool_config: tool config
:param invoke_from: invoke from
"""
show_retrieve_source = self.app_orchestration_config.show_retrieve_source

hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=self.queue_manager,
app_id=self.message.app_id,
message_id=self.message.id,
user_id=self.user_id,
invoke_from=invoke_from
)

# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
).first()

# pass if dataset is not available
if not dataset:
return None

# pass if dataset is not available
if (dataset and dataset.available_document_count == 0
and dataset.available_document_count == 0):
return None

# get retrieval model config
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}

retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model

# get top k
top_k = retrieval_model_config['top_k']

# get score threshold
score_threshold = None
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")

tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=show_retrieve_source,
retriever_from=invoke_from.to_source()
)

return tool
@@ -1,13 +1,8 @@
import logging
from typing import Optional

from flask import current_app

from core.embedding.cached_embedding import CacheEmbedding
from core.entities.application_entities import InvokeFrom
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
@@ -45,17 +40,6 @@ class AnnotationReplyFeature:
embedding_provider_name = collection_binding_detail.provider_name
embedding_model_name = collection_binding_detail.model_name

model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=app_record.tenant_id,
provider=embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=embedding_model_name
)

# get embedding model
embeddings = CacheEmbedding(model_instance)

dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name,
embedding_model_name,
@@ -71,22 +55,14 @@ class AnnotationReplyFeature:
collection_binding_id=dataset_collection_binding.id
)

vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings,
attributes=['doc_id', 'annotation_id', 'app_id']
)
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])

documents = vector_index.search(
documents = vector.search_by_vector(
query=query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 1,
'score_threshold': score_threshold,
'filter': {
'group_id': [dataset.id]
}
top_k=1,
score_threshold=score_threshold,
filter={
'group_id': [dataset.id]
}
)
@@ -1,8 +1,9 @@
import json
import logging
import uuid
from datetime import datetime
from mimetypes import guess_extension
from typing import List, Optional, Tuple, Union, cast
from typing import Optional, Union, cast

from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager
@@ -20,7 +21,14 @@ from core.file.message_file_parser import FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -50,7 +58,7 @@ class BaseAssistantApplicationRunner(AppRunner):
message: Message,
user_id: str,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[List[PromptMessage]] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None
@@ -77,7 +85,9 @@ class BaseAssistantApplicationRunner(AppRunner):
self.message = message
self.user_id = user_id
self.memory = memory
self.history_prompt_messages = prompt_messages
self.history_prompt_messages = self.organize_agent_history(
prompt_messages=prompt_messages or []
)
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance
@@ -104,6 +114,7 @@ class BaseAssistantApplicationRunner(AppRunner):
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
MessageAgentThought.message_id == self.message.id,
).count()
db.session.close()

# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
@@ -122,7 +133,7 @@ class BaseAssistantApplicationRunner(AppRunner):

return app_orchestration_config

def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
"""
Handle tool response
"""
@@ -134,19 +145,19 @@ class BaseAssistantApplicationRunner(AppRunner):
result += f"result link: {response.message}. please tell user to check it."
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result += f"image has been created and sent to user already, you should tell user to check it now."
result += "image has been created and sent to user already, you should tell user to check it now."
else:
result += f"tool response: {response.message}."

return result

def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
"""
convert tool to prompt message tool
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
tenant_id=self.application_generate_entity.tenant_id,
tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
agent_tool=tool,
agent_callback=self.agent_callback
)
tool_entity.load_variables(self.variables_pool)
@@ -161,33 +172,11 @@ class BaseAssistantApplicationRunner(AppRunner):
}
)

runtime_parameters = {}

parameters = tool_entity.parameters or []
user_parameters = tool_entity.get_runtime_parameters() or []

# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
found = False
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
found = True
break

if found:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
else:
# add new parameter
parameters.append(parameter)

parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue

parameter_type = 'string'
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
@@ -203,59 +192,16 @@ class BaseAssistantApplicationRunner(AppRunner):
else:
raise ValueError(f"parameter type {parameter.type} is not supported")

if parameter.form == ToolParameter.ToolParameterForm.FORM:
# get tool parameter from form
tool_parameter_config = tool.tool_parameters.get(parameter.name)
if not tool_parameter_config:
# get default value
tool_parameter_config = parameter.default
if not tool_parameter_config and parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config")

if parameter.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter.options))
if tool_parameter_config not in options:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")

# convert tool parameter config to correct type
try:
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
# check if tool parameter is integer
if isinstance(tool_parameter_config, int):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, float):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, str):
if '.' in tool_parameter_config:
tool_parameter_config = float(tool_parameter_config)
else:
tool_parameter_config = int(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
tool_parameter_config = bool(tool_parameter_config)
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
tool_parameter_config = str(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType:
tool_parameter_config = str(tool_parameter_config)
except Exception as e:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")

# save tool parameter to tool entity memory
runtime_parameters[parameter.name] = tool_parameter_config

elif parameter.form == ToolParameter.ToolParameterForm.LLM:
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}

if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum
if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum

if parameter.required:
message_tool.parameters['required'].append(parameter.name)

tool_entity.runtime.runtime_parameters.update(runtime_parameters)
if parameter.required:
message_tool.parameters['required'].append(parameter.name)

return message_tool, tool_entity
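The removed block above coerced FORM-type tool parameter values to their declared types; the number rule parses a string as float when it contains a decimal point, otherwise as int. A compact sketch of that coercion rule, written outside any Dify types:

from typing import Union

def coerce_number(value: Union[int, float, str]) -> Union[int, float]:
    # Numbers pass through unchanged; strings are parsed as float when they
    # contain a dot, otherwise as int (mirroring the removed logic).
    if isinstance(value, (int, float)):
        return value
    return float(value) if '.' in value else int(value)

assert coerce_number("3.5") == 3.5
assert coerce_number("42") == 42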
@@ -295,6 +241,9 @@ class BaseAssistantApplicationRunner(AppRunner):
tool_runtime_parameters = tool.get_runtime_parameters() or []

for parameter in tool_runtime_parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue

parameter_type = 'string'
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
@@ -310,22 +259,21 @@ class BaseAssistantApplicationRunner(AppRunner):
else:
raise ValueError(f"parameter type {parameter.type} is not supported")

if parameter.form == ToolParameter.ToolParameterForm.LLM:
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}

if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum

if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)

return prompt_tool

def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
"""
Extract tool response binary
"""
@@ -356,7 +304,7 @@ class BaseAssistantApplicationRunner(AppRunner):

return result

def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]:
"""
Create message file

@@ -394,17 +342,20 @@ class BaseAssistantApplicationRunner(AppRunner):
created_by=self.user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)

result.append((
message_file,
message.save_as
))

db.session.commit()

db.session.close()

return result

def create_agent_thought(self, message_id: str, message: str,
tool_name: str, tool_input: str, messages_ids: List[str]
tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought:
"""
Create agent thought
@@ -437,6 +388,8 @@ class BaseAssistantApplicationRunner(AppRunner):

db.session.add(thought)
db.session.commit()
db.session.refresh(thought)
db.session.close()

self.agent_thought_count += 1

@@ -449,11 +402,15 @@ class BaseAssistantApplicationRunner(AppRunner):
thought: str,
observation: str,
answer: str,
messages_ids: List[str],
messages_ids: list[str],
llm_usage: LLMUsage = None) -> MessageAgentThought:
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).filter(
MessageAgentThought.id == agent_thought.id
).first()

if thought is not None:
agent_thought.thought = thought

@@ -504,19 +461,9 @@ class BaseAssistantApplicationRunner(AppRunner):
agent_thought.tool_labels_str = json.dumps(labels)

db.session.commit()

def get_history_prompt_messages(self) -> List[PromptMessage]:
"""
Get history prompt messages
"""
if self.history_prompt_messages is None:
self.history_prompt_messages = db.session.query(PromptMessage).filter(
PromptMessage.message_id == self.message.id,
).order_by(PromptMessage.position.asc()).all()

return self.history_prompt_messages
db.session.close()

def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]:
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
"""
Transform tool message into agent thought
"""
@@ -587,6 +534,69 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
convert tool variables to db variables
"""
db_variables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
).first()

db_variables.updated_at = datetime.utcnow()
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
db.session.commit()
db.session.close()

def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize agent history
"""
result = []
# check if there is a system message in the beginning of the conversation
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
result.append(prompt_messages[0])

messages: list[Message] = db.session.query(Message).filter(
Message.conversation_id == self.message.conversation_id,
).order_by(Message.created_at.asc()).all()

for message in messages:
result.append(UserPromptMessage(content=message.query))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(';')
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
tool_inputs = json.loads(agent_thought.tool_input)
for tool in tools:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(AssistantPromptMessage.ToolCall(
id=tool_call_id,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool,
arguments=json.dumps(tool_inputs.get(tool, {})),
)
))
tool_call_response.append(ToolPromptMessage(
content=agent_thought.observation,
name=tool,
tool_call_id=tool_call_id,
))

result.extend([
AssistantPromptMessage(
content=agent_thought.thought,
tool_calls=tool_calls,
),
*tool_call_response
])
if not tools:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:
result.append(AssistantPromptMessage(content=message.answer))

db.session.close()

return result
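organize_agent_history above rebuilds an OpenAI-style tool-calling transcript: each stored agent thought becomes an assistant message carrying tool_calls, followed by tool messages that echo the same tool_call_id. A schematic sketch of the shape being produced, using plain dicts instead of Dify's PromptMessage classes:

import json
import uuid

def build_turn(thought: str, tool_name: str, tool_args: dict, observation: str) -> list[dict]:
    call_id = str(uuid.uuid4())  # a fresh id links the call to its response
    return [
        {"role": "assistant", "content": thought,
         "tool_calls": [{"id": call_id, "type": "function",
                         "function": {"name": tool_name,
                                      "arguments": json.dumps(tool_args)}}]},
        {"role": "tool", "tool_call_id": call_id, "name": tool_name,
         "content": observation},
    ]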
@@ -1,6 +1,7 @@
import json
import re
from typing import Dict, Generator, List, Literal, Union
from collections.abc import Generator
from typing import Literal, Union

from core.application_queue_manager import PublishFrom
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
@@ -11,6 +12,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -26,10 +28,13 @@ from models.model import Conversation, Message


class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
_is_first_iteration = True
_ignore_observation_providers = ['wenxin']

def run(self, conversation: Conversation,
message: Message,
query: str,
inputs: Dict[str, str],
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
"""
Run Cot agent application
@@ -37,12 +42,11 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
app_orchestration_config = self.app_orchestration_config
self._repack_app_orchestration_config(app_orchestration_config)

agent_scratchpad: List[AgentScratchpadUnit] = []
agent_scratchpad: list[AgentScratchpadUnit] = []
self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)

# check model mode
if self.app_orchestration_config.model_config.mode == "completion":
# TODO: stop words
if 'Observation' not in app_orchestration_config.model_config.stop:
if 'Observation' not in app_orchestration_config.model_config.stop:
if app_orchestration_config.model_config.provider not in self._ignore_observation_providers:
app_orchestration_config.model_config.stop.append('Observation')

# override inputs
@@ -56,7 +60,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
prompt_messages = self.history_prompt_messages

# convert tools into ModelRuntime Tool format
prompt_messages_tools: List[PromptMessageTool] = []
prompt_messages_tools: list[PromptMessageTool] = []
tool_instances = {}
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
try:
@@ -83,7 +87,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
}
final_answer = ''

def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
@@ -127,64 +131,99 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
input=query
)

# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# recalc llm max tokens
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
llm_result: LLMResult = model_instance.invoke_llm(
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
tools=[],
stop=app_orchestration_config.model_config.stop,
stream=False,
stream=True,
user=self.user_id,
callbacks=[],
)

# check llm result
if not llm_result:
if not chunks:
raise ValueError("failed to invoke llm")

# get scratchpad
scratchpad = self._extract_response_scratchpad(llm_result.message.content)
agent_scratchpad.append(scratchpad)

# get llm usage
if llm_result.usage:
increase_usage(llm_usage, llm_result.usage)

usage_dict = {}
react_chunks = self._handle_stream_react(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response='',
thought='',
action_str='',
observation='',
action=None,
)

# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

for chunk in react_chunks:
if isinstance(chunk, dict):
scratchpad.agent_response += json.dumps(chunk)
try:
if scratchpad.action:
raise Exception("")
scratchpad.action_str = json.dumps(chunk)
scratchpad.action = AgentScratchpadUnit.Action(
action_name=chunk['action'],
action_input=chunk['action_input']
)
except:
scratchpad.thought += json.dumps(chunk)
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint='',
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=json.dumps(chunk)
),
usage=None
)
)
else:
scratchpad.agent_response += chunk
scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint='',
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=chunk
),
usage=None
)
)

scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
agent_scratchpad.append(scratchpad)

# get llm usage
if 'usage' in usage_dict:
increase_usage(llm_usage, usage_dict['usage'])
else:
usage_dict['usage'] = LLMUsage.empty_usage()

self.save_agent_thought(agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else '',
tool_input=scratchpad.action.action_input if scratchpad.action else '',
thought=scratchpad.thought,
observation='',
answer=llm_result.message.content,
answer=scratchpad.agent_response,
messages_ids=[],
llm_usage=llm_result.usage)
llm_usage=usage_dict['usage'])

if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

# publish agent thought if it's not empty and there is a action
if scratchpad.thought and scratchpad.action:
# check if final answer
if not scratchpad.action.action_name.lower() == "final answer":
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=scratchpad.thought
),
usage=llm_result.usage,
),
system_fingerprint=''
)

if not scratchpad.action:
# failed to extract action, return final answer directly
final_answer = scratchpad.agent_response or ''
@@ -218,9 +257,15 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# invoke tool
error_response = None
try:
if isinstance(tool_call_args, str):
try:
tool_call_args = json.loads(tool_call_args)
except json.JSONDecodeError:
pass

tool_response = tool_instance.invoke(
user_id=self.user_id,
tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
tool_parameters=tool_call_args
)
# transform tool response to llm friendly response
tool_response = self.transform_tool_invoke_messages(tool_response)
@@ -238,7 +283,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):

message_file_ids = [message_file.id for message_file, _ in message_files]
except ToolProviderCredentialValidationError as e:
error_response = f"Please check your tool provider credentials"
error_response = "Please check your tool provider credentials"
except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e:
@@ -259,7 +304,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):

# save scratchpad
scratchpad.observation = observation
scratchpad.agent_response = llm_result.message.content

# save agent thought
self.save_agent_thought(
@@ -268,7 +312,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
tool_input=tool_call_args,
thought=None,
observation=observation,
answer=llm_result.message.content,
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
@@ -315,6 +359,97 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
system_fingerprint=''
), PublishFrom.APPLICATION_MANAGER)

def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
-> Generator[Union[str, dict], None, None]:
def parse_json(json_str):
try:
return json.loads(json_str.strip())
except:
return json_str

def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
if not code_blocks:
return
for block in code_blocks:
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
yield parse_json(json_text)

code_block_cache = ''
code_block_delimiter_count = 0
in_code_block = False
json_cache = ''
json_quote_count = 0
in_json = False
got_json = False

for response in llm_response:
response = response.delta.message.content
if not isinstance(response, str):
continue

# stream
index = 0
while index < len(response):
steps = 1
delta = response[index:index+steps]
if delta == '`':
code_block_cache += delta
code_block_delimiter_count += 1
else:
if not in_code_block:
if code_block_delimiter_count > 0:
yield code_block_cache
code_block_cache = ''
else:
code_block_cache += delta
code_block_delimiter_count = 0

if code_block_delimiter_count == 3:
if in_code_block:
yield from extra_json_from_code_block(code_block_cache)
code_block_cache = ''

in_code_block = not in_code_block
code_block_delimiter_count = 0

if not in_code_block:
# handle single json
if delta == '{':
json_quote_count += 1
in_json = True
json_cache += delta
elif delta == '}':
json_cache += delta
if json_quote_count > 0:
json_quote_count -= 1
if json_quote_count == 0:
in_json = False
got_json = True
index += steps
continue
else:
if in_json:
json_cache += delta

if got_json:
got_json = False
yield parse_json(json_cache)
json_cache = ''
json_quote_count = 0
in_json = False

if not in_code_block and not in_json:
yield delta.replace('`', '')

index += steps

if code_block_cache:
yield code_block_cache

if json_cache:
yield parse_json(json_cache)

def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
"""
fill in inputs from external data tools
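_handle_stream_react above scans the token stream character by character, tracking backtick fences and brace depth so that complete JSON action blocks are yielded as dicts while everything else passes through as plain text. A small usage sketch of the same brace-matching idea on a static string (simplified: no code fences, no streaming; this is not the Dify method itself):

import json

def extract_json_objects(text: str) -> list[dict]:
    # Track brace depth; when it returns to zero, one candidate object closes.
    objects, depth, start = [], 0, None
    for i, ch in enumerate(text):
        if ch == '{':
            if depth == 0:
                start = i
            depth += 1
        elif ch == '}' and depth > 0:
            depth -= 1
            if depth == 0:
                try:
                    objects.append(json.loads(text[start:i + 1]))
                except json.JSONDecodeError:
                    pass  # ignore fragments that are not valid JSON
    return objects

print(extract_json_objects('thought {"action": "search", "action_input": "dify"} done'))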
@@ -326,122 +461,40 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
continue
|
||||
|
||||
return instruction
|
||||
|
||||
def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
|
||||
|
||||
def _init_agent_scratchpad(self,
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
messages: list[PromptMessage]
|
||||
) -> list[AgentScratchpadUnit]:
|
||||
"""
|
||||
extract response from llm response
|
||||
init agent scratchpad
|
||||
"""
|
||||
def extra_quotes() -> AgentScratchpadUnit:
|
||||
agent_response = content
|
||||
# try to extract all quotes
|
||||
pattern = re.compile(r'```(.*?)```', re.DOTALL)
|
||||
quotes = pattern.findall(content)
|
||||
|
||||
# try to extract action from end to start
|
||||
for i in range(len(quotes) - 1, 0, -1):
|
||||
"""
|
||||
1. use json load to parse action
|
||||
2. use plain text `Action: xxx` to parse action
|
||||
"""
|
||||
try:
|
||||
action = json.loads(quotes[i].replace('```', ''))
|
||||
action_name = action.get("action")
|
||||
action_input = action.get("action_input")
|
||||
agent_thought = agent_response.replace(quotes[i], '')
|
||||
|
||||
if action_name and action_input:
|
||||
return AgentScratchpadUnit(
|
||||
agent_response=content,
|
||||
thought=agent_thought,
|
||||
action_str=quotes[i],
|
||||
action=AgentScratchpadUnit.Action(
|
||||
action_name=action_name,
|
||||
action_input=action_input,
|
||||
)
|
||||
current_scratchpad: AgentScratchpadUnit = None
|
||||
for message in messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or 'I am thinking about how to help you',
|
||||
action_str='',
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
if message.tool_calls:
|
||||
try:
|
||||
current_scratchpad.action = AgentScratchpadUnit.Action(
|
||||
action_name=message.tool_calls[0].function.name,
|
||||
action_input=json.loads(message.tool_calls[0].function.arguments)
|
||||
)
|
||||
except:
|
||||
# try to parse action from plain text
|
||||
action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
|
||||
action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
|
||||
# delete action from agent response
|
||||
agent_thought = agent_response.replace(quotes[i], '')
|
||||
# remove extra quotes
|
||||
agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
|
||||
# remove Action: xxx from agent thought
|
||||
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
|
||||
|
||||
if action_name and action_input:
|
||||
return AgentScratchpadUnit(
|
||||
agent_response=content,
|
||||
thought=agent_thought,
|
||||
action_str=quotes[i],
|
||||
action=AgentScratchpadUnit.Action(
|
||||
action_name=action_name[0],
|
||||
action_input=action_input[0],
|
||||
)
|
||||
)
|
||||
|
||||
def extra_json():
|
||||
agent_response = content
|
||||
# try to extract all json
|
||||
structures, pair_match_stack = [], []
|
||||
started_at, end_at = 0, 0
|
||||
for i in range(len(content)):
|
||||
if content[i] == '{':
|
||||
pair_match_stack.append(i)
|
||||
if len(pair_match_stack) == 1:
|
||||
started_at = i
|
||||
elif content[i] == '}':
|
||||
begin = pair_match_stack.pop()
|
||||
if not pair_match_stack:
|
||||
end_at = i + 1
|
||||
structures.append((content[begin:i+1], (started_at, end_at)))
|
||||
|
||||
# handle the last character
|
||||
if pair_match_stack:
|
||||
end_at = len(content)
|
||||
structures.append((content[pair_match_stack[0]:], (started_at, end_at)))
|
||||
|
||||
for i in range(len(structures), 0, -1):
|
||||
try:
|
||||
json_content, (started_at, end_at) = structures[i - 1]
|
||||
action = json.loads(json_content)
|
||||
action_name = action.get("action")
|
||||
action_input = action.get("action_input")
|
||||
# delete json content from agent response
|
||||
agent_thought = agent_response[:started_at] + agent_response[end_at:]
|
||||
# remove extra quotes like ```(json)*\n\n```
|
||||
agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
|
||||
# remove Action: xxx from agent thought
|
||||
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
|
||||
|
||||
if action_name and action_input is not None:
|
||||
return AgentScratchpadUnit(
|
||||
agent_response=content,
|
||||
thought=agent_thought,
|
||||
action_str=json_content,
|
||||
action=AgentScratchpadUnit.Action(
|
||||
action_name=action_name,
|
||||
action_input=action_input,
|
||||
)
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
agent_scratchpad = extra_quotes()
|
||||
if agent_scratchpad:
|
||||
return agent_scratchpad
|
||||
agent_scratchpad = extra_json()
|
||||
if agent_scratchpad:
|
||||
return agent_scratchpad
|
||||
|
||||
return AgentScratchpadUnit(
|
||||
agent_response=content,
|
||||
thought=content,
|
||||
action_str='',
|
||||
action=None
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
agent_scratchpad.append(current_scratchpad)
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if current_scratchpad:
|
||||
current_scratchpad.observation = message.content
|
||||
|
||||
return agent_scratchpad
|
||||
|
||||
def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
|
||||
agent_prompt_message: AgentPromptEntity,
|
||||
):
|
||||
@@ -473,7 +526,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
next_iteration = agent_prompt_message.next_iteration
|
||||
|
||||
if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
|
||||
raise ValueError(f"first_prompt or next_iteration is required in CoT agent mode")
|
||||
raise ValueError("first_prompt or next_iteration is required in CoT agent mode")
|
||||
|
||||
# check instruction, tools, and tool_names slots
|
||||
if not first_prompt.find("{{instruction}}") >= 0:
|
||||
@@ -493,7 +546,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
if not next_iteration.find("{{observation}}") >= 0:
|
||||
raise ValueError("{{observation}} is required in next_iteration")
|
||||
|
||||
-   def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
+   def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
        """
        convert agent scratchpad list to str
        """
@@ -501,18 +554,19 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):

        result = ''
        for scratchpad in agent_scratchpad:
-           result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"
+           result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \
+               next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available')

        return result
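
The rewritten accumulation line guards every field with `or`, so a scratchpad whose thought, action_str, or observation is None can no longer break the concatenation, and a missing observation is replaced by a fixed fallback string. A rough standalone rendering of the new behavior (the dataclass and template below are stand-ins for the real entities):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Scratchpad:  # stand-in for AgentScratchpadUnit
    thought: Optional[str] = None
    action_str: Optional[str] = None
    observation: Optional[str] = None

def to_str(scratchpads: list[Scratchpad], next_iteration: str) -> str:
    result = ''
    for s in scratchpads:
        result += (s.thought or '') + (s.action_str or '') + \
            next_iteration.replace("{{observation}}", s.observation or 'It seems that no response is available')
    return result

print(to_str([Scratchpad(thought="Check weather. ", action_str='{"action": "weather"}')],
             "\nObservation: {{observation}}\n"))
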

    def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
-                                     prompt_messages: List[PromptMessage],
-                                     tools: List[PromptMessageTool],
-                                     agent_scratchpad: List[AgentScratchpadUnit],
+                                     prompt_messages: list[PromptMessage],
+                                     tools: list[PromptMessageTool],
+                                     agent_scratchpad: list[AgentScratchpadUnit],
                                      agent_prompt_message: AgentPromptEntity,
                                      instruction: str,
                                      input: str,
-                                     ) -> List[PromptMessage]:
+                                     ) -> list[PromptMessage]:
        """
        organize chain of thought prompt messages, a standard prompt message is like:
        Respond to the human as helpfully and accurately as possible.
@@ -555,35 +609,45 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
        # organize prompt messages
        if mode == "chat":
            # override system message
-           overrided = False
+           overridden = False
            prompt_messages = prompt_messages.copy()
            for prompt_message in prompt_messages:
                if isinstance(prompt_message, SystemPromptMessage):
                    prompt_message.content = system_message
-                   overrided = True
+                   overridden = True
                    break

            # convert tool prompt messages to user prompt messages
            for idx, prompt_message in enumerate(prompt_messages):
                if isinstance(prompt_message, ToolPromptMessage):
                    prompt_messages[idx] = UserPromptMessage(
                        content=prompt_message.content
                    )

-           if not overrided:
+           if not overridden:
                prompt_messages.insert(0, SystemPromptMessage(
                    content=system_message,
                ))

            # add assistant message
-           if len(agent_scratchpad) > 0:
+           if len(agent_scratchpad) > 0 and not self._is_first_iteration:
                prompt_messages.append(AssistantPromptMessage(
-                   content=(agent_scratchpad[-1].thought or '')
+                   content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''),
                ))

            # add user message
-           if len(agent_scratchpad) > 0:
+           if len(agent_scratchpad) > 0 and not self._is_first_iteration:
                prompt_messages.append(UserPromptMessage(
-                   content=(agent_scratchpad[-1].observation or ''),
+                   content=(agent_scratchpad[-1].observation or 'It seems that no response is available'),
                ))

+           self._is_first_iteration = False

            return prompt_messages
        elif mode == "completion":
            # parse agent scratchpad
            agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
+           self._is_first_iteration = False
            # parse prompt messages
            return [UserPromptMessage(
                content=first_prompt.replace("{{instruction}}", instruction)

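
The _is_first_iteration flag introduced in this hunk keeps the runner from replaying a scratchpad turn before the model has produced one; from the second iteration on, the last thought plus its raw action JSON is appended as an assistant message and the observation as a user message. A schematic of that gating with message classes reduced to strings (all names here are illustrative):

def organize(messages: list[str], scratchpad: list[dict], is_first_iteration: bool) -> tuple[list[str], bool]:
    messages = messages.copy()
    if scratchpad and not is_first_iteration:
        last = scratchpad[-1]
        # replay the previous turn: thought plus raw action JSON, then its observation
        messages.append("assistant: " + (last.get("thought") or '') + (last.get("action_str") or ''))
        messages.append("user: " + (last.get("observation") or 'It seems that no response is available'))
    return messages, False  # the flag is cleared after the first pass

msgs, first = organize(["system: ..."], [], True)   # first pass: nothing replayed
msgs, first = organize(msgs, [{"thought": "t", "action_str": "{}"}], first)
print(msgs)  # ['system: ...', 'assistant: t{}', 'user: It seems that no response is available']
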
@@ -1,6 +1,7 @@
import json
import logging
-from typing import Any, Dict, Generator, List, Tuple, Union
+from collections.abc import Generator
+from typing import Any, Union

from core.application_queue_manager import PublishFrom
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
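
Most of the remaining changes in this diff are the same Python 3.9+ typing cleanup: builtin generics (list, dict, tuple) replace typing.List/Dict/Tuple per PEP 585, and Generator moves to collections.abc, where the typing alias is deprecated. A before/after sketch of the idiom (the function is illustrative):

from collections.abc import Generator

# before (legal but deprecated since Python 3.9):
#   from typing import Dict, Generator, List, Tuple
#   def chunks(items: List[str]) -> Generator[Tuple[int, str], None, None]: ...

def chunks(items: list[str]) -> Generator[tuple[int, str], None, None]:
    # builtin generics need no typing imports on Python >= 3.9
    for idx, item in enumerate(items):
        yield idx, item

usage: dict[str, int] = {}
for idx, item in chunks(["a", "b"]):
    usage[item] = idx
print(usage)  # {'a': 0, 'b': 1}
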
@@ -44,7 +45,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
        )

        # convert tools into ModelRuntime Tool format
-       prompt_messages_tools: List[PromptMessageTool] = []
+       prompt_messages_tools: list[PromptMessageTool] = []
        tool_instances = {}
        for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
            try:
@@ -70,13 +71,13 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):

        # continue to run until there is not any tool call
        function_call_state = True
-       agent_thoughts: List[MessageAgentThought] = []
+       agent_thoughts: list[MessageAgentThought] = []
        llm_usage = {
            'usage': None
        }
        final_answer = ''

-       def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
+       def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
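
increase_usage mutates the dict it receives instead of rebinding a local, which is how the nested function shares one accumulator with the enclosing run loop. A reduced sketch with a stand-in usage type (LLMUsage's real fields are not shown in this diff, so total_tokens is an assumption):

from dataclasses import dataclass

@dataclass
class Usage:  # stand-in for LLMUsage
    total_tokens: int = 0

def increase_usage(final_llm_usage_dict: dict, usage: Usage):
    # mutate the shared dict in place; rebinding a plain local would be
    # invisible to the caller
    if not final_llm_usage_dict['usage']:
        final_llm_usage_dict['usage'] = usage
    else:
        final_llm_usage_dict['usage'].total_tokens += usage.total_tokens

llm_usage = {'usage': None}
increase_usage(llm_usage, Usage(10))
increase_usage(llm_usage, Usage(5))
print(llm_usage['usage'].total_tokens)  # 15
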
@@ -104,8 +105,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
messages_ids=message_file_ids
|
||||
)
|
||||
|
||||
# recale llm max tokens
|
||||
self.recale_llm_max_tokens(self.model_config, prompt_messages)
|
||||
# recalc llm max tokens
|
||||
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
|
||||
# invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
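
chunks is typed as a Union because invoke_llm yields a chunk generator when streaming and a complete LLMResult otherwise, so the caller must branch on which it received. A schematic consumer of such a dual-mode API (stand-in types; the real runner uses its own checks and entities):

from collections.abc import Generator
from typing import Union

def invoke_llm(stream: bool) -> Union[Generator[str, None, None], str]:
    # stand-in for model_instance.invoke_llm: generator when streaming,
    # a complete result otherwise
    if stream:
        return (part for part in ("Hel", "lo"))
    return "Hello"

chunks = invoke_llm(stream=True)
if isinstance(chunks, Generator):
    response = ''.join(chunks)   # streaming path: accumulate chunk by chunk
else:
    response = chunks            # blocking path: the full result at once
print(response)  # Hello
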
@@ -117,7 +118,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
            callbacks=[],
        )

-       tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []
+       tool_calls: list[tuple[str, str, dict[str, Any]]] = []

        # save full response
        response = ''
@@ -277,7 +278,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
                message_file_ids.append(message_file.id)

            except ToolProviderCredentialValidationError as e:
-               error_response = f"Please check your tool provider credentials"
+               error_response = "Please check your tool provider credentials"
            except (
                ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
            ) as e:
@@ -364,7 +365,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
                return True
        return False

-   def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
+   def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract tool calls from llm result chunk

@@ -381,7 +382,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):

        return tool_calls

-   def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
+   def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract blocking tool calls from llm result

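
Both extractors reduce provider tool calls to (id, name, parsed-arguments) tuples; they differ only in whether the source is a streamed chunk delta or the final result message. A rough sketch of that reduction over a hypothetical OpenAI-style tool-call shape (the field names are assumptions, not the exact dify entities):

import json
from typing import Any

def extract_tool_calls(tool_calls: list[dict]) -> list[tuple[str, str, dict[str, Any]]]:
    # normalize provider tool calls into (id, tool_name, parsed_args) tuples
    extracted = []
    for call in tool_calls:
        extracted.append((
            call["id"],
            call["function"]["name"],
            json.loads(call["function"]["arguments"]),
        ))
    return extracted

calls = [{"id": "call_1", "function": {"name": "weather", "arguments": '{"city": "Kyoto"}'}}]
print(extract_tool_calls(calls))  # [('call_1', 'weather', {'city': 'Kyoto'})]
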
@@ -1,5 +1,5 @@
import logging
-from typing import List, Optional
+from typing import Optional

from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.model_runtime.callbacks.base_callback import Callback
@@ -17,7 +17,7 @@ class AgentLLMCallback(Callback):

    def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
-                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
        """
        Before invoke callback
@@ -38,7 +38,7 @@ class AgentLLMCallback(Callback):

    def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
                     prompt_messages: list[PromptMessage], model_parameters: dict,
-                    tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                    tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                     stream: bool = True, user: Optional[str] = None):
        """
        On new chunk callback
@@ -58,7 +58,7 @@ class AgentLLMCallback(Callback):

    def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
                        prompt_messages: list[PromptMessage], model_parameters: dict,
-                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> None:
        """
        After invoke callback
@@ -80,7 +80,7 @@ class AgentLLMCallback(Callback):

    def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
                        prompt_messages: list[PromptMessage], model_parameters: dict,
-                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> None:
        """
        Invoke error callback
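
Taken together, the four hooks give AgentLLMCallback one callback per stage of an LLM invocation: before, per chunk, after, and on error. A minimal illustration of the same hook pattern detached from the dify base class (the class below is a stand-in, not the real Callback API):

import logging
from typing import Optional

class LoggingLLMCallback:
    # stand-in showing the hook-per-stage pattern used by AgentLLMCallback
    def __init__(self) -> None:
        self.logger = logging.getLogger(__name__)

    def on_before_invoke(self, model: str, prompt_messages: list[str],
                         stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
        self.logger.info("invoking %s with %d prompt messages", model, len(prompt_messages))

    def on_invoke_error(self, ex: Exception, model: str) -> None:
        self.logger.error("invoke of %s failed: %s", model, ex)

logging.basicConfig(level=logging.INFO)
callback = LoggingLLMCallback()
callback.on_before_invoke(model="some-model", prompt_messages=["hello"])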