mirror of
https://github.com/langgenius/dify.git
synced 2026-01-20 05:54:02 +00:00
Compare commits
102 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5bd3b02be6 | ||
|
|
3cf5c1853d | ||
|
|
a4d86496e1 | ||
|
|
90bdc85f8c | ||
|
|
0828873b52 | ||
|
|
816b707a16 | ||
|
|
c9257ab4bf | ||
|
|
69ce3b3d33 | ||
|
|
c4caa7c401 | ||
|
|
dc93a292c3 | ||
|
|
174ee1b646 | ||
|
|
9b1c4f47fb | ||
|
|
582ba45c00 | ||
|
|
f1cbd55007 | ||
|
|
3a34370422 | ||
|
|
29ab244de6 | ||
|
|
920b2c2b40 | ||
|
|
ac96d192a6 | ||
|
|
07fbeb6cf0 | ||
|
|
fc64cdee64 | ||
|
|
0c0e96c55f | ||
|
|
5b953c1ef2 | ||
|
|
562ca45e07 | ||
|
|
6bbd53512e | ||
|
|
e352a8ed1b | ||
|
|
e55225e2bc | ||
|
|
3e63abd335 | ||
|
|
0620fa3094 | ||
|
|
d93288f711 | ||
|
|
ca69af7b97 | ||
|
|
952e13fef8 | ||
|
|
4be3087642 | ||
|
|
49da8a23a8 | ||
|
|
3ad943a9eb | ||
|
|
3082093293 | ||
|
|
b03bbab5ad | ||
|
|
9574730050 | ||
|
|
91ea6fe4ee | ||
|
|
769be13189 | ||
|
|
e42175241e | ||
|
|
12257b438b | ||
|
|
9ecc736c30 | ||
|
|
6c4e6bf1d6 | ||
|
|
97fe817186 | ||
|
|
52b12ed7eb | ||
|
|
d8ab4474b4 | ||
|
|
1ecbd95adf | ||
|
|
cad6e6624f | ||
|
|
3505cbe05c | ||
|
|
e15359e589 | ||
|
|
edb86f5f5a | ||
|
|
adf2651d1f | ||
|
|
5031d64e28 | ||
|
|
ae3ad59b16 | ||
|
|
e6cd7b0467 | ||
|
|
97e9f52331 | ||
|
|
25957d917a | ||
|
|
20b932da97 | ||
|
|
207080babc | ||
|
|
48bacd01cc | ||
|
|
297d0f1f30 | ||
|
|
eedbe1b770 | ||
|
|
5ff6b1da07 | ||
|
|
8b49e0ee2a | ||
|
|
e031ec9359 | ||
|
|
1bd1cd6938 | ||
|
|
81c5a21b8d | ||
|
|
61e4bbabaf | ||
|
|
4cf475680d | ||
|
|
ca4aa340f6 | ||
|
|
767d8a4b05 | ||
|
|
0b8dcaba8f | ||
|
|
af6a318aae | ||
|
|
c6e2900be7 | ||
|
|
963d9b6032 | ||
|
|
b2ee738bb1 | ||
|
|
c8ca3ff404 | ||
|
|
5d8fa2c7af | ||
|
|
58df5e5376 | ||
|
|
348ad1a624 | ||
|
|
73e17d5aa8 | ||
|
|
300d9892a5 | ||
|
|
e47b5b43b8 | ||
|
|
21c9d9e200 | ||
|
|
4f6916c4d8 | ||
|
|
8633957726 | ||
|
|
0850c953b3 | ||
|
|
23e95fd7ab | ||
|
|
e1045f01c6 | ||
|
|
e6d22fc3a0 | ||
|
|
9232244920 | ||
|
|
476eb90a90 | ||
|
|
063191889d | ||
|
|
589099a005 | ||
|
|
a0ec7de058 | ||
|
|
14a19a3da9 | ||
|
|
1b04382a9b | ||
|
|
71e5828d41 | ||
|
|
65a02f7d32 | ||
|
|
acf9174bef | ||
|
|
243ca5b1e2 | ||
|
|
f6059c377c |
4
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
4
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@@ -10,7 +10,9 @@ body:
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/document_issue.yml
vendored
4
.github/ISSUE_TEMPLATE/document_issue.yml
vendored
@@ -10,7 +10,9 @@ body:
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
- label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
4
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
@@ -10,7 +10,9 @@ body:
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/help_wanted.yml
vendored
4
.github/ISSUE_TEMPLATE/help_wanted.yml
vendored
@@ -10,7 +10,9 @@ body:
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/translation_issue.yml
vendored
4
.github/ISSUE_TEMPLATE/translation_issue.yml
vendored
@@ -10,7 +10,9 @@ body:
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: input
|
||||
attributes:
|
||||
|
||||
30
.github/pull_request_template.md
vendored
Normal file
30
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
# Description
|
||||
|
||||
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
|
||||
|
||||
Fixes # (issue)
|
||||
|
||||
## Type of Change
|
||||
|
||||
Please delete options that are not relevant.
|
||||
|
||||
- [ ] Bug fix (non-breaking change which fixes an issue)
|
||||
- [ ] New feature (non-breaking change which adds functionality)
|
||||
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
||||
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
|
||||
|
||||
# How Has This Been Tested?
|
||||
|
||||
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration
|
||||
|
||||
- [ ] TODO
|
||||
|
||||
# Suggested Checklist:
|
||||
|
||||
- [ ] I have performed a self-review of my own code
|
||||
- [ ] I have commented my code, particularly in hard-to-understand areas
|
||||
- [ ] My changes generate no new warnings
|
||||
- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
|
||||
- [ ] `optional` I have made corresponding changes to the documentation
|
||||
- [ ] `optional` I have added tests that prove my fix is effective or that my feature works
|
||||
- [ ] `optional` New and existing unit tests pass locally with my changes
|
||||
34
.github/workflows/tool-test-sdks.yaml
vendored
Normal file
34
.github/workflows/tool-test-sdks.yaml
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
name: Run Unit Test For SDKs
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
jobs:
|
||||
build:
|
||||
name: unit test for Node.js SDK
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
node-version: [16, 18, 20]
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: sdks/nodejs-client
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Use Node.js ${{ matrix.node-version }}
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
cache: ''
|
||||
cache-dependency-path: 'yarn.lock'
|
||||
|
||||
- name: Install Dependencies
|
||||
run: yarn install
|
||||
|
||||
- name: Test
|
||||
run: yarn test
|
||||
11
README.md
11
README.md
@@ -21,6 +21,17 @@
|
||||
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://discord.com/events/1082486657678311454/1211724120996188220" target="_blank">
|
||||
Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
|
||||
</a>
|
||||
<ul align="center" style="text-decoration: none; list-style: none;">
|
||||
<li> US EST: 09:00 (9:00 AM)</li>
|
||||
<li> CET: 15:00 (3:00 PM)</li>
|
||||
<li> CST: 22:00 (10:00 PM)</li>
|
||||
</ul>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
|
||||
Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
|
||||
|
||||
@@ -81,11 +81,17 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
||||
# Model Configuration
|
||||
MULTIMODAL_SEND_IMAGE_FORMAT=base64
|
||||
|
||||
# Mail configuration, support: resend
|
||||
MAIL_TYPE=
|
||||
# Mail configuration, support: resend, smtp
|
||||
MAIL_TYPE=resend
|
||||
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
|
||||
RESEND_API_KEY=
|
||||
RESEND_API_URL=https://api.resend.com
|
||||
# smtp configuration
|
||||
SMTP_SERVER=smtp.gmail.com
|
||||
SMTP_PORT=587
|
||||
SMTP_USERNAME=123
|
||||
SMTP_PASSWORD=abc
|
||||
SMTP_USE_TLS=false
|
||||
|
||||
# Sentry configuration
|
||||
SENTRY_DSN=
|
||||
@@ -124,3 +130,5 @@ UNSTRUCTURED_API_URL=
|
||||
|
||||
SSRF_PROXY_HTTP_URL=
|
||||
SSRF_PROXY_HTTPS_URL=
|
||||
|
||||
BATCH_UPLOAD_LIMIT=10
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
@@ -39,10 +38,11 @@ from extensions import (
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
from libs.passport import PassportService
|
||||
|
||||
# DO NOT REMOVE BELOW
|
||||
from services.account_service import AccountService
|
||||
|
||||
# DO NOT REMOVE BELOW
|
||||
from events import event_handlers
|
||||
from models import account, dataset, model, source, task, tool, tools, web
|
||||
# DO NOT REMOVE ABOVE
|
||||
|
||||
|
||||
|
||||
158
api/commands.py
158
api/commands.py
@@ -6,15 +6,15 @@ import click
|
||||
from flask import current_app
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, password_pattern, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import Tenant
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account
|
||||
from models.provider import Provider, ProviderModel
|
||||
|
||||
@@ -124,14 +124,15 @@ def reset_encrypt_key_pair():
|
||||
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
|
||||
|
||||
|
||||
@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
|
||||
def create_qdrant_indexes():
|
||||
@click.command('vdb-migrate', help='migrate vector db.')
|
||||
def vdb_migrate():
|
||||
"""
|
||||
Migrate other vector database datas to Qdrant.
|
||||
Migrate vector database datas to target vector database .
|
||||
"""
|
||||
click.echo(click.style('Start create qdrant indexes.', fg='green'))
|
||||
click.echo(click.style('Start migrate vector db.', fg='green'))
|
||||
create_count = 0
|
||||
|
||||
config = current_app.config
|
||||
vector_type = config.get('VECTOR_STORE')
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
@@ -140,54 +141,101 @@ def create_qdrant_indexes():
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
model_manager = ModelManager()
|
||||
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
if dataset.index_struct_dict:
|
||||
if dataset.index_struct_dict['type'] != 'qdrant':
|
||||
try:
|
||||
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
|
||||
try:
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
|
||||
|
||||
index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if index:
|
||||
index.create_qdrant_dataset(dataset)
|
||||
index_struct = {
|
||||
"type": 'qdrant',
|
||||
"vector_store": {
|
||||
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct)
|
||||
db.session.commit()
|
||||
create_count += 1
|
||||
else:
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
try:
|
||||
click.echo('Create dataset vdb index: {}'.format(dataset.id))
|
||||
if dataset.index_struct_dict:
|
||||
if dataset.index_struct_dict['type'] == vector_type:
|
||||
continue
|
||||
if vector_type == "weaviate":
|
||||
dataset_id = dataset.id
|
||||
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
index_struct_dict = {
|
||||
"type": 'weaviate',
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == "qdrant":
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
|
||||
one_or_none()
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError('Dataset Collection Bindings is not exist!')
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
index_struct_dict = {
|
||||
"type": 'qdrant',
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
|
||||
elif vector_type == "milvus":
|
||||
dataset_id = dataset.id
|
||||
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
index_struct_dict = {
|
||||
"type": 'milvus',
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
else:
|
||||
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
||||
|
||||
vector = Vector(dataset)
|
||||
click.echo(f"vdb_migrate {dataset.id}")
|
||||
|
||||
try:
|
||||
vector.delete()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == 'completed',
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
|
||||
if documents:
|
||||
try:
|
||||
vector.create(documents)
|
||||
except Exception as e:
|
||||
raise e
|
||||
click.echo(f"Dataset {dataset.id} create successfully.")
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
create_count += 1
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
continue
|
||||
|
||||
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
|
||||
|
||||
@@ -196,4 +244,4 @@ def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
app.cli.add_command(reset_encrypt_key_pair)
|
||||
app.cli.add_command(create_qdrant_indexes)
|
||||
app.cli.add_command(vdb_migrate)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
|
||||
import dotenv
|
||||
@@ -39,7 +38,9 @@ DEFAULTS = {
|
||||
'LOG_LEVEL': 'INFO',
|
||||
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
|
||||
'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
|
||||
'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
|
||||
'HOSTED_OPENAI_PAID_ENABLED': 'False',
|
||||
'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
|
||||
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
|
||||
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
|
||||
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
|
||||
@@ -57,6 +58,8 @@ DEFAULTS = {
|
||||
'BILLING_ENABLED': 'False',
|
||||
'CAN_REPLACE_LOGO': 'False',
|
||||
'ETL_TYPE': 'dify',
|
||||
'KEYWORD_STORE': 'jieba',
|
||||
'BATCH_UPLOAD_LIMIT': 20
|
||||
}
|
||||
|
||||
|
||||
@@ -87,7 +90,7 @@ class Config:
|
||||
# ------------------------
|
||||
# General Configurations.
|
||||
# ------------------------
|
||||
self.CURRENT_VERSION = "0.5.4"
|
||||
self.CURRENT_VERSION = "0.5.7"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
@@ -183,7 +186,7 @@ class Config:
|
||||
# Currently, only support: qdrant, milvus, zilliz, weaviate
|
||||
# ------------------------
|
||||
self.VECTOR_STORE = get_env('VECTOR_STORE')
|
||||
|
||||
self.KEYWORD_STORE = get_env('KEYWORD_STORE')
|
||||
# qdrant settings
|
||||
self.QDRANT_URL = get_env('QDRANT_URL')
|
||||
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
|
||||
@@ -209,6 +212,12 @@ class Config:
|
||||
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
|
||||
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
|
||||
self.RESEND_API_URL = get_env('RESEND_API_URL')
|
||||
# SMTP settings
|
||||
self.SMTP_SERVER = get_env('SMTP_SERVER')
|
||||
self.SMTP_PORT = get_env('SMTP_PORT')
|
||||
self.SMTP_USERNAME = get_env('SMTP_USERNAME')
|
||||
self.SMTP_PASSWORD = get_env('SMTP_PASSWORD')
|
||||
self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS')
|
||||
|
||||
# ------------------------
|
||||
# Workpace Configurations.
|
||||
@@ -254,8 +263,10 @@ class Config:
|
||||
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
|
||||
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
|
||||
self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
|
||||
self.HOSTED_OPENAI_TRIAL_MODELS = get_env('HOSTED_OPENAI_TRIAL_MODELS')
|
||||
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
|
||||
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
|
||||
self.HOSTED_OPENAI_PAID_MODELS = get_env('HOSTED_OPENAI_PAID_MODELS')
|
||||
|
||||
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
|
||||
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
|
||||
@@ -280,6 +291,8 @@ class Config:
|
||||
self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
|
||||
self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
|
||||
|
||||
self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
|
||||
|
||||
|
||||
class CloudEditionConfig(Config):
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
|
||||
import json
|
||||
|
||||
from models.model import AppModelConfig
|
||||
|
||||
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT']
|
||||
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA']
|
||||
|
||||
language_timezone_mapping = {
|
||||
'en-US': 'America/New_York',
|
||||
@@ -16,8 +15,10 @@ language_timezone_mapping = {
|
||||
'ko-KR': 'Asia/Seoul',
|
||||
'ru-RU': 'Europe/Moscow',
|
||||
'it-IT': 'Europe/Rome',
|
||||
'uk-UA': 'Europe/Kyiv',
|
||||
}
|
||||
|
||||
|
||||
def supported_language(lang):
|
||||
if lang in languages:
|
||||
return lang
|
||||
@@ -26,6 +27,7 @@ def supported_language(lang):
|
||||
.format(lang=lang))
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
user_input_form_template = {
|
||||
"en-US": [
|
||||
{
|
||||
@@ -67,6 +69,16 @@ user_input_form_template = {
|
||||
}
|
||||
}
|
||||
],
|
||||
"ua-UK": [
|
||||
{
|
||||
"paragraph": {
|
||||
"label": "Запит",
|
||||
"variable": "default_input",
|
||||
"required": False,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
demo_model_templates = {
|
||||
@@ -145,7 +157,7 @@ demo_model_templates = {
|
||||
'Italian',
|
||||
]
|
||||
}
|
||||
},{
|
||||
}, {
|
||||
"paragraph": {
|
||||
"label": "Query",
|
||||
"variable": "query",
|
||||
@@ -272,7 +284,7 @@ demo_model_templates = {
|
||||
"意大利语",
|
||||
]
|
||||
}
|
||||
},{
|
||||
}, {
|
||||
"paragraph": {
|
||||
"label": "文本内容",
|
||||
"variable": "query",
|
||||
@@ -323,5 +335,130 @@ demo_model_templates = {
|
||||
)
|
||||
}
|
||||
],
|
||||
'uk-UA': [{
|
||||
"name": "Помічник перекладу",
|
||||
"icon": "",
|
||||
"icon_background": "",
|
||||
"description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.",
|
||||
"mode": "completion",
|
||||
"model_config": AppModelConfig(
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo-instruct",
|
||||
configs={
|
||||
"prompt_template": "Будь ласка, перекладіть наступний текст на {{target_language}}:\n",
|
||||
"prompt_variables": [
|
||||
{
|
||||
"key": "target_language",
|
||||
"name": "Цільова мова",
|
||||
"description": "Мова, на яку ви хочете перекласти.",
|
||||
"type": "select",
|
||||
"default": "Ukrainian",
|
||||
"options": [
|
||||
"Chinese",
|
||||
"English",
|
||||
"Japanese",
|
||||
"French",
|
||||
"Russian",
|
||||
"German",
|
||||
"Spanish",
|
||||
"Korean",
|
||||
"Italian",
|
||||
],
|
||||
},
|
||||
],
|
||||
"completion_params": {
|
||||
"max_token": 1000,
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"presence_penalty": 0.1,
|
||||
"frequency_penalty": 0.1,
|
||||
},
|
||||
},
|
||||
opening_statement="",
|
||||
suggested_questions=None,
|
||||
pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"presence_penalty": 0.1,
|
||||
"frequency_penalty": 0.1,
|
||||
},
|
||||
}),
|
||||
user_input_form=json.dumps([
|
||||
{
|
||||
"select": {
|
||||
"label": "Цільова мова",
|
||||
"variable": "target_language",
|
||||
"description": "Мова, на яку ви хочете перекласти.",
|
||||
"default": "Chinese",
|
||||
"required": True,
|
||||
'options': [
|
||||
'Chinese',
|
||||
'English',
|
||||
'Japanese',
|
||||
'French',
|
||||
'Russian',
|
||||
'German',
|
||||
'Spanish',
|
||||
'Korean',
|
||||
'Italian',
|
||||
]
|
||||
}
|
||||
}, {
|
||||
"paragraph": {
|
||||
"label": "Запит",
|
||||
"variable": "query",
|
||||
"required": True,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
])
|
||||
)
|
||||
},
|
||||
{
|
||||
"name": "AI інтерв’юер фронтенду",
|
||||
"icon": "",
|
||||
"icon_background": "",
|
||||
"description": "Симульований інтерв’юер фронтенду, який перевіряє рівень кваліфікації у розробці фронтенду через опитування.",
|
||||
"mode": "chat",
|
||||
"model_config": AppModelConfig(
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
configs={
|
||||
"introduction": "Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
|
||||
"prompt_template": "Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
|
||||
"prompt_variables": [],
|
||||
"completion_params": {
|
||||
"max_token": 300,
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.9,
|
||||
"presence_penalty": 0.1,
|
||||
"frequency_penalty": 0.1,
|
||||
},
|
||||
},
|
||||
opening_statement="Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
|
||||
suggested_questions=None,
|
||||
pre_prompt="Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.9,
|
||||
"presence_penalty": 0.1,
|
||||
"frequency_penalty": 0.1,
|
||||
},
|
||||
}),
|
||||
user_input_form=None
|
||||
),
|
||||
}
|
||||
],
|
||||
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
@@ -125,19 +124,13 @@ class AppListApi(Resource):
|
||||
available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
|
||||
provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
|
||||
if provider_model not in available_models_names:
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if not model_instance:
|
||||
if not default_model_entity:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Default System Reasoning Model available. Please configure "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Default System Reasoning Model available. Please configure "
|
||||
"in the Settings -> Model Provider.")
|
||||
else:
|
||||
model_config_dict["model"]["provider"] = model_instance.provider
|
||||
model_config_dict["model"]["name"] = model_instance.model
|
||||
model_config_dict["model"]["provider"] = default_model_entity.provider
|
||||
model_config_dict["model"]["name"] = default_model_entity.model
|
||||
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
@@ -46,7 +45,8 @@ class ChatMessageAudioApi(Resource):
|
||||
try:
|
||||
response = AudioService.transcript_asr(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file
|
||||
file=file,
|
||||
end_user=None,
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -72,7 +72,7 @@ class ChatMessageAudioApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logging.exception(f"internal server error, {str(e)}.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@@ -83,10 +83,12 @@ class ChatMessageTextApi(Resource):
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
app_model = _get_app(app_id, None)
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_tts(
|
||||
tenant_id=app_model.tenant_id,
|
||||
text=request.form['text'],
|
||||
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=False
|
||||
)
|
||||
|
||||
@@ -113,9 +115,50 @@ class ChatMessageTextApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logging.exception(f"internal server error, {str(e)}.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TextModesApi(Resource):
|
||||
def get(self, app_id: str):
|
||||
app_model = _get_app(str(app_id))
|
||||
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('language', type=str, required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
response = AudioService.transcript_tts_voices(
|
||||
tenant_id=app_model.tenant_id,
|
||||
language=args['language'],
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.audio.ProviderNotSupportTextToSpeechLanageServiceError:
|
||||
raise AppUnavailableError("Text to audio voices language parameter loss.")
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception(f"internal server error, {str(e)}.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
|
||||
api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
|
||||
api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
import flask_login
|
||||
from flask import Response, stream_with_context
|
||||
@@ -169,8 +169,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
@@ -246,8 +247,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import flask_login
|
||||
from flask import current_app, request
|
||||
from flask_restful import Resource, reqparse
|
||||
@@ -8,7 +7,7 @@ from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from libs.helper import email
|
||||
from libs.password import valid_password
|
||||
from services.account_service import AccountService
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
|
||||
class LoginApi(Resource):
|
||||
@@ -30,6 +29,8 @@ class LoginApi(Resource):
|
||||
except services.errors.account.AccountLoginError:
|
||||
return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
# todo: return the user info
|
||||
|
||||
@@ -10,7 +10,7 @@ from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from models.account import Account, AccountStatus
|
||||
from services.account_service import AccountService, RegisterService
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
|
||||
from .. import api
|
||||
|
||||
@@ -76,6 +76,8 @@ class OAuthCallback(Resource):
|
||||
account.initialized_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
token = AccountService.get_account_jwt_token(account)
|
||||
|
||||
@@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.data_loader.loader.notion import NotionLoader
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||
from libs.login import login_required
|
||||
@@ -173,14 +174,15 @@ class DataSourceNotionApi(Resource):
|
||||
if not data_source_binding:
|
||||
raise NotFound('Data source binding not found.')
|
||||
|
||||
loader = NotionLoader(
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
|
||||
text_docs = loader.load()
|
||||
text_docs = extractor.extract()
|
||||
return {
|
||||
'content': "\n".join([doc.page_content for doc in text_docs])
|
||||
}, 200
|
||||
@@ -192,11 +194,31 @@ class DataSourceNotionApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
notion_info_list = args['notion_info_list']
|
||||
extract_settings = []
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
for page in notion_info['pages']:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page['page_id'],
|
||||
"notion_page_type": page['type'],
|
||||
"tenant_id": current_user.current_tenant_id
|
||||
},
|
||||
document_model=args['doc_form']
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
|
||||
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'])
|
||||
return response, 200
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import flask_restful
|
||||
from flask import current_app, request
|
||||
from flask_login import current_user
|
||||
@@ -16,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
@@ -179,9 +179,9 @@ class DatasetApi(Resource):
|
||||
location='json', store_missing=False,
|
||||
type=_validate_description_length)
|
||||
parser.add_argument('indexing_technique', type=str, location='json',
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid indexing technique.')
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid indexing technique.')
|
||||
parser.add_argument('permission', type=str, location='json', choices=(
|
||||
'only_me', 'all_team_members'), help='Invalid permission.')
|
||||
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
|
||||
@@ -259,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('indexing_technique', type=str, required=True,
|
||||
parser.add_argument('indexing_technique', type=str, required=True,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
@@ -269,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
extract_settings = []
|
||||
if args['info_list']['data_source_type'] == 'upload_file':
|
||||
file_ids = args['info_list']['file_info_list']['file_ids']
|
||||
file_details = db.session.query(UploadFile).filter(
|
||||
@@ -279,37 +280,45 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
if file_details is None:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'],
|
||||
args['indexing_technique'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
if file_details:
|
||||
for file_detail in file_details:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
upload_file=file_detail,
|
||||
document_model=args['doc_form']
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif args['info_list']['data_source_type'] == 'notion_import':
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
||||
try:
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
|
||||
args['info_list']['notion_info_list'],
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'],
|
||||
args['indexing_technique'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
notion_info_list = args['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
for page in notion_info['pages']:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page['page_id'],
|
||||
"notion_page_type": page['type'],
|
||||
"tenant_id": current_user.current_tenant_id
|
||||
},
|
||||
document_model=args['doc_form']
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'],
|
||||
args['indexing_technique'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
return response, 200
|
||||
|
||||
|
||||
@@ -509,4 +518,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
|
||||
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
||||
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
|
||||
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
@@ -34,6 +32,7 @@ from core.indexing_runner import IndexingRunner
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.document_fields import (
|
||||
@@ -71,7 +70,7 @@ class DocumentResource(Resource):
|
||||
|
||||
return document
|
||||
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
@@ -97,7 +96,7 @@ class GetProcessRuleApi(Resource):
|
||||
req_data = request.args
|
||||
|
||||
document_id = req_data.get('document_id')
|
||||
|
||||
|
||||
# get default rules
|
||||
mode = DocumentService.DEFAULT_RULES['mode']
|
||||
rules = DocumentService.DEFAULT_RULES['rules']
|
||||
@@ -296,8 +295,8 @@ class DatasetInitApi(Resource):
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@@ -364,16 +363,22 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
if not file:
|
||||
raise NotFound('File not found.')
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
upload_file=file,
|
||||
document_model=document.doc_form
|
||||
)
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
|
||||
data_process_rule_dict, None,
|
||||
'English', dataset_id)
|
||||
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
|
||||
data_process_rule_dict, document.doc_form,
|
||||
'English', dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@@ -404,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
info_list = []
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in ['completed', 'error']:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
@@ -426,42 +432,49 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
}
|
||||
info_list.append(notion_info)
|
||||
|
||||
if dataset.data_source_type == 'upload_file':
|
||||
file_details = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == current_user.current_tenant_id,
|
||||
UploadFile.id.in_(info_list)
|
||||
).all()
|
||||
if document.data_source_type == 'upload_file':
|
||||
file_id = data_source_info['upload_file_id']
|
||||
file_detail = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == current_user.current_tenant_id,
|
||||
UploadFile.id == file_id
|
||||
).first()
|
||||
|
||||
if file_details is None:
|
||||
raise NotFound("File not found.")
|
||||
if file_detail is None:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
upload_file=file_detail,
|
||||
document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
elif document.data_source_type == 'notion_import':
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"notion_workspace_id": data_source_info['notion_workspace_id'],
|
||||
"notion_obj_id": data_source_info['notion_page_id'],
|
||||
"notion_page_type": data_source_info['type'],
|
||||
"tenant_id": current_user.current_tenant_id
|
||||
},
|
||||
document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
|
||||
data_process_rule_dict, None,
|
||||
'English', dataset_id)
|
||||
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
|
||||
data_process_rule_dict, document.doc_form,
|
||||
'English', dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
elif dataset.data_source_type == 'notion_import':
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
|
||||
info_list,
|
||||
data_process_rule_dict,
|
||||
None, 'English', dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
@@ -143,8 +142,8 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@@ -234,8 +233,8 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
try:
|
||||
@@ -286,8 +285,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
|
||||
@@ -76,8 +76,8 @@ class HitTestingApi(Resource):
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
@@ -86,6 +85,7 @@ class ChatTextApi(InstalledAppResource):
|
||||
response = AudioService.transcript_tts(
|
||||
tenant_id=app_model.tenant_id,
|
||||
text=request.form['text'],
|
||||
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=False
|
||||
)
|
||||
return {'data': response.data.decode('latin1')}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Generator, Union
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
@@ -164,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import current_user
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
@@ -123,8 +123,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
@@ -78,7 +77,7 @@ class ExploreAppMetaApi(InstalledAppResource):
|
||||
# get all tools
|
||||
tools = agent_config.get('tools', [])
|
||||
url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
+ f"/console/api/workspaces/current/tool-provider/builtin/")
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/")
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from sqlalchemy import and_
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from functools import wraps
|
||||
|
||||
from flask import current_app, request
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import current_app
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, abort, fields, marshal_with, reqparse
|
||||
@@ -12,6 +11,7 @@ from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.account_service import RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
@@ -72,6 +72,13 @@ class MemberInviteEmailApi(Resource):
|
||||
'email': invitee_email,
|
||||
'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
|
||||
})
|
||||
except AccountAlreadyInTenantError:
|
||||
invitation_results.append({
|
||||
'status': 'success',
|
||||
'email': invitee_email,
|
||||
'url': f'{console_web_url}/signin'
|
||||
})
|
||||
break
|
||||
except Exception as e:
|
||||
invitation_results.append({
|
||||
'status': 'failed',
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
from functools import wraps
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class WorkspaceWebappLogoApi(Resource):
|
||||
webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None
|
||||
|
||||
if not webapp_logo_file_id:
|
||||
raise NotFound(f'webapp logo is not found')
|
||||
raise NotFound('webapp logo is not found')
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService.get_public_image_preview(
|
||||
|
||||
@@ -32,7 +32,7 @@ class ToolFilePreviewApi(Resource):
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise NotFound(f'file is not found')
|
||||
raise NotFound('file is not found')
|
||||
|
||||
generator, mimetype = result
|
||||
except Exception:
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def create_or_update_end_user_for_user_id(app_model, user_id):
|
||||
"""
|
||||
Create or update session terminal based on user ID.
|
||||
"""
|
||||
end_user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.type == 'service_api'
|
||||
).first()
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='service_api',
|
||||
is_anonymous=True,
|
||||
session_id=user_id
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
from flask_restful import fields, marshal_with, Resource
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
|
||||
class AppParameterApi(AppApiResource):
|
||||
class AppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
variable_fields = {
|
||||
@@ -43,8 +42,9 @@ class AppParameterApi(AppApiResource):
|
||||
'system_parameters': fields.Nested(system_parameters_fields)
|
||||
}
|
||||
|
||||
@validate_app_token
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, app_model: App, end_user):
|
||||
def get(self, app_model: App):
|
||||
"""Retrieve app parameters."""
|
||||
app_model_config = app_model.app_model_config
|
||||
|
||||
@@ -65,8 +65,9 @@ class AppParameterApi(AppApiResource):
|
||||
}
|
||||
}
|
||||
|
||||
class AppMetaApi(AppApiResource):
|
||||
def get(self, app_model: App, end_user):
|
||||
class AppMetaApi(Resource):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
"""Get app meta"""
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
@@ -78,7 +79,7 @@ class AppMetaApi(AppApiResource):
|
||||
# get all tools
|
||||
tools = agent_config.get('tools', [])
|
||||
url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
+ f"/console/api/workspaces/current/tool-provider/builtin/")
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/")
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restful import reqparse
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, AppModelConfig
|
||||
from models.model import App, AppModelConfig, EndUser
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@@ -30,8 +30,9 @@ from services.errors.audio import (
|
||||
)
|
||||
|
||||
|
||||
class AudioApi(AppApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
class AudioApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.speech_to_text_dict['enabled']:
|
||||
@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TextApi(AppApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
class TextApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -85,7 +86,8 @@ class TextApi(AppApiResource):
|
||||
response = AudioService.transcript_tts(
|
||||
tenant_id=app_model.tenant_id,
|
||||
text=args['text'],
|
||||
end_user=args['user'],
|
||||
end_user=end_user,
|
||||
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=args['streaming']
|
||||
)
|
||||
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import reqparse
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
@@ -18,17 +18,19 @@ from controllers.service_api.app.error import (
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, EndUser
|
||||
from services.completion_service import CompletionService
|
||||
|
||||
|
||||
class CompletionApi(AppApiResource):
|
||||
def post(self, app_model, end_user):
|
||||
class CompletionApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != 'completion':
|
||||
raise AppUnavailableError()
|
||||
|
||||
@@ -37,16 +39,12 @@ class CompletionApi(AppApiResource):
|
||||
parser.add_argument('query', type=str, location='json', default='')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
try:
|
||||
@@ -81,29 +79,20 @@ class CompletionApi(AppApiResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class CompletionStopApi(AppApiResource):
|
||||
def post(self, app_model, end_user, task_id):
|
||||
class CompletionStopApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, task_id):
|
||||
if app_model.mode != 'completion':
|
||||
raise AppUnavailableError()
|
||||
|
||||
if end_user is None:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
user = args.get('user')
|
||||
if user is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user)
|
||||
else:
|
||||
raise ValueError("arg user muse be input.")
|
||||
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class ChatApi(AppApiResource):
|
||||
def post(self, app_model, end_user):
|
||||
class ChatApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
@@ -113,7 +102,6 @@ class ChatApi(AppApiResource):
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
|
||||
|
||||
@@ -121,9 +109,6 @@ class ChatApi(AppApiResource):
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
try:
|
||||
response = CompletionService.completion(
|
||||
app_model=app_model,
|
||||
@@ -156,22 +141,12 @@ class ChatApi(AppApiResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class ChatStopApi(AppApiResource):
|
||||
def post(self, app_model, end_user, task_id):
|
||||
class ChatStopApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, task_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
if end_user is None:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
user = args.get('user')
|
||||
if user is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user)
|
||||
else:
|
||||
raise ValueError("arg user muse be input.")
|
||||
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
@@ -182,8 +157,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,53 +1,44 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import request
|
||||
from flask_restful import marshal_with, reqparse
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
|
||||
class ConversationApi(AppApiResource):
|
||||
class ConversationApi(Resource):
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('user', type=str, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
try:
|
||||
return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
|
||||
except services.errors.conversation.LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
class ConversationDetailApi(AppApiResource):
|
||||
class ConversationDetailApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
user = request.get_json().get('user')
|
||||
|
||||
if end_user is None and user is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user)
|
||||
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
@@ -55,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class ConversationRenameApi(AppApiResource):
|
||||
class ConversationRenameApi(Resource):
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model, end_user, c_id):
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
@@ -66,13 +58,9 @@ class ConversationRenameApi(AppApiResource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=False, location='json')
|
||||
parser.add_argument('user', type=str, location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
try:
|
||||
return ConversationService.rename(
|
||||
app_model,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
|
||||
@@ -1,30 +1,27 @@
|
||||
from flask import request
|
||||
from flask_restful import marshal_with
|
||||
from flask_restful import Resource, marshal_with
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||
from controllers.service_api.app.error import (
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from fields.file_fields import file_fields
|
||||
from models.model import App, EndUser
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class FileApi(AppApiResource):
|
||||
class FileApi(Resource):
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
|
||||
@marshal_with(file_fields)
|
||||
def post(self, app_model, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
|
||||
file = request.files['file']
|
||||
user_args = request.form.get('user')
|
||||
|
||||
if end_user is None and user_args is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user_args)
|
||||
|
||||
# check file
|
||||
if 'file' not in request.files:
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from extensions.ext_database import db
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from models.model import EndUser, Message
|
||||
from models.model import App, EndUser
|
||||
from services.message_service import MessageService
|
||||
|
||||
|
||||
class MessageListApi(AppApiResource):
|
||||
class MessageListApi(Resource):
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
}
|
||||
@@ -71,8 +68,9 @@ class MessageListApi(AppApiResource):
|
||||
'data': fields.List(fields.Nested(message_fields))
|
||||
}
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
@@ -80,12 +78,8 @@ class MessageListApi(AppApiResource):
|
||||
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
||||
parser.add_argument('first_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('user', type=str, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(app_model, end_user,
|
||||
args['conversation_id'], args['first_id'], args['limit'])
|
||||
@@ -95,18 +89,15 @@ class MessageListApi(AppApiResource):
|
||||
raise NotFound("First Message Not Exists.")
|
||||
|
||||
|
||||
class MessageFeedbackApi(AppApiResource):
|
||||
def post(self, app_model, end_user, message_id):
|
||||
class MessageFeedbackApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def post(self, app_model: App, end_user: EndUser, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
||||
parser.add_argument('user', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
|
||||
except services.errors.message.MessageNotExistsError:
|
||||
@@ -115,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class MessageSuggestedApi(AppApiResource):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
class MessageSuggestedApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def get(self, app_model: App, end_user: EndUser, message_id):
|
||||
message_id = str(message_id)
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
try:
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
).first()
|
||||
|
||||
if end_user is None and message.from_end_user_id is not None:
|
||||
user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.id == message.from_end_user_id,
|
||||
EndUser.type == 'service_api'
|
||||
).first()
|
||||
else:
|
||||
user = end_user
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
user=end_user,
|
||||
message_id=message_id,
|
||||
check_enabled=False
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal, reqparse
|
||||
from sqlalchemy import desc
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@@ -46,8 +46,8 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# validate args
|
||||
@@ -90,8 +90,8 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@@ -182,8 +182,8 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
|
||||
@@ -1,23 +1,40 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask_restful import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import _get_user
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.model import ApiToken, App
|
||||
from models.model import ApiToken, App, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def validate_app_token(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
class WhereisUserArg(Enum):
|
||||
"""
|
||||
Enum for whereis_user_arg.
|
||||
"""
|
||||
QUERY = 'query'
|
||||
JSON = 'json'
|
||||
FORM = 'form'
|
||||
|
||||
|
||||
class FetchUserArg(BaseModel):
|
||||
fetch_from: WhereisUserArg
|
||||
required: bool = False
|
||||
|
||||
|
||||
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token('app')
|
||||
|
||||
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
|
||||
@@ -30,16 +47,35 @@ def validate_app_token(view=None):
|
||||
if not app_model.enable_api:
|
||||
raise NotFound()
|
||||
|
||||
return view(app_model, None, *args, **kwargs)
|
||||
return decorated
|
||||
kwargs['app_model'] = app_model
|
||||
|
||||
if view:
|
||||
if fetch_user_arg:
|
||||
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
|
||||
user_id = request.args.get('user')
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
|
||||
user_id = request.get_json().get('user')
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
|
||||
user_id = request.form.get('user')
|
||||
else:
|
||||
# use default-user
|
||||
user_id = None
|
||||
|
||||
if not user_id and fetch_user_arg.required:
|
||||
raise ValueError("Arg user must be provided.")
|
||||
|
||||
if user_id:
|
||||
user_id = str(user_id)
|
||||
|
||||
kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
# if view is None, it means that the decorator is used without parentheses
|
||||
# use the decorator as a function for method_decorators
|
||||
return decorator
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str,
|
||||
api_token_type: str,
|
||||
@@ -129,8 +165,33 @@ def validate_and_get_api_token(scope=None):
|
||||
return api_token
|
||||
|
||||
|
||||
class AppApiResource(Resource):
|
||||
method_decorators = [validate_app_token]
|
||||
def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
|
||||
"""
|
||||
Create or update session terminal based on user ID.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = 'DEFAULT-USER'
|
||||
|
||||
end_user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.app_id == app_model.id,
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.type == 'service_api'
|
||||
).first()
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type='service_api',
|
||||
is_anonymous=True if user_id == 'DEFAULT-USER' else False,
|
||||
session_id=user_id
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
|
||||
class DatasetApiResource(Resource):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
@@ -77,7 +76,7 @@ class AppMeta(WebApiResource):
|
||||
# get all tools
|
||||
tools = agent_config.get('tools', [])
|
||||
url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
+ f"/console/api/workspaces/current/tool-provider/builtin/")
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/")
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
@@ -69,17 +68,23 @@ class AudioApi(WebApiResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logging.exception(f"internal server error: {str(e)}")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TextApi(WebApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.text_to_speech_dict['enabled']:
|
||||
raise AppUnavailableError()
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_tts(
|
||||
tenant_id=app_model.tenant_id,
|
||||
text=request.form['text'],
|
||||
end_user=end_user.external_user_id,
|
||||
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=False
|
||||
)
|
||||
|
||||
@@ -106,7 +111,7 @@ class TextApi(WebApiResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logging.exception(f"internal server error: {str(e)}")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import reqparse
|
||||
@@ -154,8 +154,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_restful import marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
@@ -160,8 +160,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
|
||||
from flask import request
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
@@ -17,7 +17,7 @@ class AgentLLMCallback(Callback):
|
||||
|
||||
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Before invoke callback
|
||||
@@ -38,7 +38,7 @@ class AgentLLMCallback(Callback):
|
||||
|
||||
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None):
|
||||
"""
|
||||
On new chunk callback
|
||||
@@ -58,7 +58,7 @@ class AgentLLMCallback(Callback):
|
||||
|
||||
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
After invoke callback
|
||||
@@ -80,7 +80,7 @@ class AgentLLMCallback(Callback):
|
||||
|
||||
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Invoke error callback
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, cast
|
||||
from typing import cast
|
||||
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
@@ -8,7 +8,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
|
||||
class CalcTokenMixin:
|
||||
|
||||
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
|
||||
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
|
||||
"""
|
||||
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||
@@ -42,7 +43,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@@ -85,7 +86,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
|
||||
def real_plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@@ -146,7 +147,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@@ -158,7 +159,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||
@@ -51,7 +52,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
@@ -125,7 +126,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@@ -207,7 +208,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
def return_stopped_response(
|
||||
self,
|
||||
early_stopping_method: str,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
**kwargs: Any,
|
||||
) -> AgentFinish:
|
||||
try:
|
||||
@@ -215,7 +216,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
except ValueError:
|
||||
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
|
||||
def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(
|
||||
self.model_config,
|
||||
@@ -264,7 +265,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
return new_messages
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
self, messages: list[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
@@ -275,7 +276,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
|
||||
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from langchain import BasePromptTemplate, PromptTemplate
|
||||
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
|
||||
@@ -68,7 +69,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@@ -125,8 +126,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
||||
) -> BasePromptTemplate:
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
@@ -153,7 +154,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
@@ -180,7 +181,7 @@ Thought: {agent_scratchpad}
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
self, intermediate_steps: list[tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
agent_scratchpad = ""
|
||||
for action, observation in intermediate_steps:
|
||||
@@ -213,8 +214,8 @@ Thought: {agent_scratchpad}
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from langchain import BasePromptTemplate, PromptTemplate
|
||||
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
|
||||
@@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@@ -127,7 +128,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||
def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
|
||||
if len(intermediate_steps) >= 2 and self.summary_model_config:
|
||||
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||
should_summary_messages = [AIMessage(content=observation)
|
||||
@@ -154,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
self, messages: list[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
@@ -173,8 +174,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
||||
) -> BasePromptTemplate:
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
@@ -200,7 +201,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
@@ -227,7 +228,7 @@ Thought: {agent_scratchpad}
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
self, intermediate_steps: list[tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
agent_scratchpad = ""
|
||||
for action, observation in intermediate_steps:
|
||||
@@ -260,8 +261,8 @@ Thought: {agent_scratchpad}
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import time
|
||||
from typing import Generator, List, Optional, Tuple, Union, cast
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
||||
from core.entities.application_entities import (
|
||||
@@ -84,7 +85,7 @@ class AppRunner:
|
||||
return rest_tokens
|
||||
|
||||
def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
|
||||
prompt_messages: List[PromptMessage]):
|
||||
prompt_messages: list[PromptMessage]):
|
||||
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
@@ -126,7 +127,7 @@ class AppRunner:
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None) \
|
||||
-> Tuple[List[PromptMessage], Optional[List[str]]]:
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
:param context:
|
||||
@@ -295,7 +296,7 @@ class AppRunner:
|
||||
tenant_id: str,
|
||||
app_orchestration_config_entity: AppOrchestrationConfigEntity,
|
||||
inputs: dict,
|
||||
query: str) -> Tuple[bool, dict, str]:
|
||||
query: str) -> tuple[bool, dict, str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
:param app_id: app id
|
||||
|
||||
@@ -38,7 +38,7 @@ class AssistantApplicationRunner(AppRunner):
|
||||
"""
|
||||
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError(f"App not found")
|
||||
raise ValueError("App not found")
|
||||
|
||||
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class BasicApplicationRunner(AppRunner):
|
||||
"""
|
||||
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError(f"App not found")
|
||||
raise ValueError("App not found")
|
||||
|
||||
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Generator, Optional, Union, cast
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -118,7 +119,7 @@ class GenerateTaskPipeline:
|
||||
}
|
||||
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
self._task_state.llm_result = event.llm_result
|
||||
else:
|
||||
@@ -174,7 +175,7 @@ class GenerateTaskPipeline:
|
||||
'id': self._message.id,
|
||||
'message_id': self._message.id,
|
||||
'mode': self._conversation.mode,
|
||||
'answer': event.llm_result.message.content,
|
||||
'answer': self._task_state.llm_result.message.content,
|
||||
'metadata': {},
|
||||
'created_at': int(self._message.created_at.timestamp())
|
||||
}
|
||||
@@ -201,7 +202,7 @@ class GenerateTaskPipeline:
|
||||
data = self._error_to_stream_response_data(self._handle_error(event))
|
||||
yield self._yield_response(data)
|
||||
break
|
||||
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
self._task_state.llm_result = event.llm_result
|
||||
else:
|
||||
@@ -353,7 +354,7 @@ class GenerateTaskPipeline:
|
||||
|
||||
yield self._yield_response(response)
|
||||
|
||||
elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
|
||||
elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
|
||||
chunk = event.chunk
|
||||
delta_text = chunk.delta.message.content
|
||||
if delta_text is None:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel
|
||||
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ModerationRule(BaseModel):
|
||||
type: str
|
||||
config: Dict[str, Any]
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class OutputModerationHandler(BaseModel):
|
||||
|
||||
@@ -2,7 +2,8 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, Generator, Optional, Tuple, Union, cast
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -27,6 +28,7 @@ from core.entities.application_entities import (
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
SensitiveWordAvoidanceEntity,
|
||||
TextToSpeechEntity,
|
||||
)
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
@@ -571,7 +573,11 @@ class ApplicationManager:
|
||||
text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
|
||||
if text_to_speech_dict:
|
||||
if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
|
||||
properties['text_to_speech'] = True
|
||||
properties['text_to_speech'] = TextToSpeechEntity(
|
||||
enabled=text_to_speech_dict.get('enabled'),
|
||||
voice=text_to_speech_dict.get('voice'),
|
||||
language=text_to_speech_dict.get('language'),
|
||||
)
|
||||
|
||||
# sensitive word avoidance
|
||||
sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
|
||||
@@ -585,7 +591,7 @@ class ApplicationManager:
|
||||
return AppOrchestrationConfigEntity(**properties)
|
||||
|
||||
def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
|
||||
-> Tuple[Conversation, Message]:
|
||||
-> tuple[Conversation, Message]:
|
||||
"""
|
||||
Initialize generate records
|
||||
:param application_generate_entity: application generate entity
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from typing import Any, Generator
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
@@ -37,7 +37,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._message_agent_thought = None
|
||||
|
||||
@property
|
||||
def agent_loops(self) -> List[AgentLoop]:
|
||||
def agent_loops(self) -> list[AgentLoop]:
|
||||
return self._agent_loops
|
||||
|
||||
def clear_agent_loops(self) -> None:
|
||||
@@ -95,14 +95,14 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -120,7 +120,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
@@ -21,7 +21,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
|
||||
def on_tool_start(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_inputs: Dict[str, Any],
|
||||
tool_inputs: dict[str, Any],
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
|
||||
@@ -29,7 +29,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
|
||||
def on_tool_end(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_inputs: Dict[str, Any],
|
||||
tool_inputs: dict[str, Any],
|
||||
tool_outputs: str,
|
||||
) -> None:
|
||||
"""If not the final action, print out observation."""
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from typing import List
|
||||
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DatasetQuery, DocumentSegment
|
||||
from models.model import DatasetRetrieverResource
|
||||
@@ -40,22 +38,26 @@ class DatasetIndexToolCallbackHandler:
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
def on_tool_end(self, documents: List[Document]) -> None:
|
||||
def on_tool_end(self, documents: list[Document]) -> None:
|
||||
"""Handle tool end."""
|
||||
for document in documents:
|
||||
doc_id = document.metadata['doc_id']
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == document.metadata['doc_id']
|
||||
)
|
||||
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if 'dataset_id' in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
|
||||
|
||||
# add hit count to document segment
|
||||
db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == doc_id
|
||||
).update(
|
||||
query.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def return_retriever_resource_info(self, resource: List):
|
||||
def return_retriever_resource_info(self, resource: list):
|
||||
"""Handle return_retriever_resource_info."""
|
||||
if resource and len(resource) > 0:
|
||||
for item in resource:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
@@ -16,8 +16,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
print_text("\n[on_chat_model_start]\n", color='blue')
|
||||
@@ -26,7 +26,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
||||
print_text(str(sub_message) + "\n", color='blue')
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out the prompts."""
|
||||
print_text("\n[on_llm_start]\n", color='blue')
|
||||
@@ -48,13 +48,13 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
||||
print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
chain_type = serialized['id'][-1]
|
||||
print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')
|
||||
|
||||
@@ -66,7 +66,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain import LLMChain as LCLLMChain
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
@@ -16,12 +16,12 @@ class LLMChain(LCLLMChain):
|
||||
model_config: ModelConfigEntity
|
||||
"""The language model instance to use."""
|
||||
llm: BaseLanguageModel = FakeLLM(response="")
|
||||
parameters: Dict[str, Any] = {}
|
||||
parameters: dict[str, Any] = {}
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
input_list: list[dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
from langchain.document_loaders import Docx2txtLoader, TextLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.data_loader.loader.csv_loader import CSVLoader
|
||||
from core.data_loader.loader.excel import ExcelLoader
|
||||
from core.data_loader.loader.html import HTMLLoader
|
||||
from core.data_loader.loader.markdown import MarkdownLoader
|
||||
from core.data_loader.loader.pdf import PdfLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
|
||||
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
|
||||
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
|
||||
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
|
||||
|
||||
class FileExtractor:
|
||||
@classmethod
|
||||
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
storage.download(upload_file.key, file_path)
|
||||
|
||||
return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
|
||||
|
||||
@classmethod
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]:
|
||||
response = requests.get(url, headers={
|
||||
"User-Agent": USER_AGENT
|
||||
})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(url).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(response.content)
|
||||
|
||||
return cls.load_from_file(file_path, return_text)
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, file_path: str, return_text: bool = False,
|
||||
upload_file: Optional[UploadFile] = None,
|
||||
is_automatic: bool = False) -> Union[List[Document], str]:
|
||||
input_file = Path(file_path)
|
||||
delimiter = '\n'
|
||||
file_extension = input_file.suffix.lower()
|
||||
etl_type = current_app.config['ETL_TYPE']
|
||||
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
|
||||
if etl_type == 'Unstructured':
|
||||
if file_extension == '.xlsx':
|
||||
loader = ExcelLoader(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
loader = PdfLoader(file_path, upload_file=upload_file)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
|
||||
else MarkdownLoader(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
loader = HTMLLoader(file_path)
|
||||
elif file_extension in ['.docx', '.doc']:
|
||||
loader = Docx2txtLoader(file_path)
|
||||
elif file_extension == '.csv':
|
||||
loader = CSVLoader(file_path, autodetect_encoding=True)
|
||||
elif file_extension == '.msg':
|
||||
loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
|
||||
elif file_extension == '.eml':
|
||||
loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
|
||||
elif file_extension == '.ppt':
|
||||
loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
|
||||
elif file_extension == '.pptx':
|
||||
loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
|
||||
elif file_extension == '.xml':
|
||||
loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
|
||||
else:
|
||||
# txt
|
||||
loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
|
||||
else TextLoader(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
if file_extension == '.xlsx':
|
||||
loader = ExcelLoader(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
loader = PdfLoader(file_path, upload_file=upload_file)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
loader = MarkdownLoader(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
loader = HTMLLoader(file_path)
|
||||
elif file_extension in ['.docx', '.doc']:
|
||||
loader = Docx2txtLoader(file_path)
|
||||
elif file_extension == '.csv':
|
||||
loader = CSVLoader(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
# txt
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
|
||||
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.document_loaders import PyPDFium2Loader
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PdfLoader(BaseLoader):
|
||||
"""Load pdf files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
upload_file: Optional[UploadFile] = None
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._upload_file = upload_file
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
plaintext_file_key = ''
|
||||
plaintext_file_exists = False
|
||||
if self._upload_file:
|
||||
if self._upload_file.hash:
|
||||
plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
|
||||
+ self._upload_file.hash + '.0625.plaintext'
|
||||
try:
|
||||
text = storage.load(plaintext_file_key).decode('utf-8')
|
||||
plaintext_file_exists = True
|
||||
return [Document(page_content=text)]
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
documents = PyPDFium2Loader(file_path=self._file_path).load()
|
||||
text_list = []
|
||||
for document in documents:
|
||||
text_list.append(document.page_content)
|
||||
text = "\n\n".join(text_list)
|
||||
|
||||
# save plaintext file for caching
|
||||
if not plaintext_file_exists and plaintext_file_key:
|
||||
storage.save(plaintext_file_key, text.encode('utf-8'))
|
||||
|
||||
return documents
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from typing import Any, Dict, Optional, Sequence, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from langchain.schema import Document
|
||||
from sqlalchemy import func
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
@@ -22,10 +23,10 @@ class DatasetDocumentStore:
|
||||
self._document_id = document_id
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore":
|
||||
def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
|
||||
return cls(**config_dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialize to dict."""
|
||||
return {
|
||||
"dataset_id": self._dataset.id,
|
||||
@@ -40,7 +41,7 @@ class DatasetDocumentStore:
|
||||
return self._user_id
|
||||
|
||||
@property
|
||||
def docs(self) -> Dict[str, Document]:
|
||||
def docs(self) -> dict[str, Document]:
|
||||
document_segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self._dataset.id
|
||||
).all()
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import base64
|
||||
import logging
|
||||
from typing import List, Optional, cast
|
||||
from typing import Optional, cast
|
||||
|
||||
import numpy as np
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import helper
|
||||
@@ -21,7 +21,7 @@ class CacheEmbedding(Embeddings):
|
||||
self._model_instance = model_instance
|
||||
self._user = user
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs in batches of 10."""
|
||||
text_embeddings = []
|
||||
try:
|
||||
@@ -52,7 +52,7 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
return text_embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
hash = helper.generate_text_hash(text)
|
||||
|
||||
@@ -42,6 +42,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
|
||||
"""
|
||||
Advanced Completion Prompt Template Entity.
|
||||
"""
|
||||
|
||||
class RolePrefixEntity(BaseModel):
|
||||
"""
|
||||
Role Prefix Entity.
|
||||
@@ -57,6 +58,7 @@ class PromptTemplateEntity(BaseModel):
|
||||
"""
|
||||
Prompt Template Entity.
|
||||
"""
|
||||
|
||||
class PromptType(Enum):
|
||||
"""
|
||||
Prompt Type.
|
||||
@@ -97,6 +99,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
||||
"""
|
||||
Dataset Retrieve Config Entity.
|
||||
"""
|
||||
|
||||
class RetrieveStrategy(Enum):
|
||||
"""
|
||||
Dataset Retrieve Strategy.
|
||||
@@ -143,6 +146,15 @@ class SensitiveWordAvoidanceEntity(BaseModel):
|
||||
config: dict[str, Any] = {}
|
||||
|
||||
|
||||
class TextToSpeechEntity(BaseModel):
|
||||
"""
|
||||
Sensitive Word Avoidance Entity.
|
||||
"""
|
||||
enabled: bool
|
||||
voice: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
|
||||
|
||||
class FileUploadEntity(BaseModel):
|
||||
"""
|
||||
File Upload Entity.
|
||||
@@ -159,6 +171,7 @@ class AgentToolEntity(BaseModel):
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any] = {}
|
||||
|
||||
|
||||
class AgentPromptEntity(BaseModel):
|
||||
"""
|
||||
Agent Prompt Entity.
|
||||
@@ -166,6 +179,7 @@ class AgentPromptEntity(BaseModel):
|
||||
first_prompt: str
|
||||
next_iteration: str
|
||||
|
||||
|
||||
class AgentScratchpadUnit(BaseModel):
|
||||
"""
|
||||
Agent First Prompt Entity.
|
||||
@@ -182,12 +196,14 @@ class AgentScratchpadUnit(BaseModel):
|
||||
thought: Optional[str] = None
|
||||
action_str: Optional[str] = None
|
||||
observation: Optional[str] = None
|
||||
action: Optional[Action] = None
|
||||
action: Optional[Action] = None
|
||||
|
||||
|
||||
class AgentEntity(BaseModel):
|
||||
"""
|
||||
Agent Entity.
|
||||
"""
|
||||
|
||||
class Strategy(Enum):
|
||||
"""
|
||||
Agent Strategy.
|
||||
@@ -202,6 +218,7 @@ class AgentEntity(BaseModel):
|
||||
tools: list[AgentToolEntity] = None
|
||||
max_iteration: int = 5
|
||||
|
||||
|
||||
class AppOrchestrationConfigEntity(BaseModel):
|
||||
"""
|
||||
App Orchestration Config Entity.
|
||||
@@ -219,7 +236,7 @@ class AppOrchestrationConfigEntity(BaseModel):
|
||||
show_retrieve_source: bool = False
|
||||
more_like_this: bool = False
|
||||
speech_to_text: bool = False
|
||||
text_to_speech: bool = False
|
||||
text_to_speech: dict = {}
|
||||
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
|
||||
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class ImagePromptMessageFile(PromptMessageFile):
|
||||
|
||||
|
||||
class LCHumanMessageWithFiles(HumanMessage):
|
||||
# content: Union[str, List[Union[str, Dict]]]
|
||||
# content: Union[str, list[Union[str, Dict]]]
|
||||
content: str
|
||||
files: list[PromptMessageFile]
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from json import JSONDecodeError
|
||||
from typing import Dict, Iterator, List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -135,7 +136,7 @@ class ProviderConfiguration(BaseModel):
|
||||
if self.provider.provider_credential_schema else []
|
||||
)
|
||||
|
||||
def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
|
||||
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
@@ -282,7 +283,7 @@ class ProviderConfiguration(BaseModel):
|
||||
return None
|
||||
|
||||
def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
|
||||
-> Tuple[ProviderModel, dict]:
|
||||
-> tuple[ProviderModel, dict]:
|
||||
"""
|
||||
Validate custom model credentials.
|
||||
|
||||
@@ -711,7 +712,7 @@ class ProviderConfigurations(BaseModel):
|
||||
Model class for provider configuration dict.
|
||||
"""
|
||||
tenant_id: str
|
||||
configurations: Dict[str, ProviderConfiguration] = {}
|
||||
configurations: dict[str, ProviderConfiguration] = {}
|
||||
|
||||
def __init__(self, tenant_id: str):
|
||||
super().__init__(tenant_id=tenant_id)
|
||||
@@ -759,7 +760,7 @@ class ProviderConfigurations(BaseModel):
|
||||
|
||||
return all_models
|
||||
|
||||
def to_list(self) -> List[ProviderConfiguration]:
|
||||
def to_list(self) -> list[ProviderConfiguration]:
|
||||
"""
|
||||
Convert to list.
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ class Extensible:
|
||||
|
||||
builtin_file_path = os.path.join(subdir_path, '__builtin__')
|
||||
if os.path.exists(builtin_file_path):
|
||||
with open(builtin_file_path, 'r', encoding='utf-8') as f:
|
||||
with open(builtin_file_path, encoding='utf-8') as f:
|
||||
position = int(f.read().strip())
|
||||
|
||||
if (extension_name + '.py') not in file_names:
|
||||
@@ -93,7 +93,7 @@ class Extensible:
|
||||
json_path = os.path.join(subdir_path, 'schema.json')
|
||||
json_data = {}
|
||||
if os.path.exists(json_path):
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
with open(json_path, encoding='utf-8') as f:
|
||||
json_data = json.load(f)
|
||||
|
||||
extensions[extension_name] = ModuleExtension(
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
|
||||
@@ -45,17 +40,6 @@ class AnnotationReplyFeature:
|
||||
embedding_provider_name = collection_binding_detail.provider_name
|
||||
embedding_model_name = collection_binding_detail.model_name
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=app_record.tenant_id,
|
||||
provider=embedding_provider_name,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=embedding_model_name
|
||||
)
|
||||
|
||||
# get embedding model
|
||||
embeddings = CacheEmbedding(model_instance)
|
||||
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_provider_name,
|
||||
embedding_model_name,
|
||||
@@ -71,22 +55,14 @@ class AnnotationReplyFeature:
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
)
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings,
|
||||
attributes=['doc_id', 'annotation_id', 'app_id']
|
||||
)
|
||||
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
|
||||
|
||||
documents = vector_index.search(
|
||||
documents = vector.search_by_vector(
|
||||
query=query,
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': 1,
|
||||
'score_threshold': score_threshold,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
top_k=1,
|
||||
score_threshold=score_threshold,
|
||||
filter={
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from mimetypes import guess_extension
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.app_runner.app_runner import AppRunner
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
@@ -20,7 +21,14 @@ from core.file.message_file_parser import FileTransferMethod
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@@ -50,7 +58,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
message: Message,
|
||||
user_id: str,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
prompt_messages: Optional[List[PromptMessage]] = None,
|
||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance = None
|
||||
@@ -77,7 +85,9 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
self.message = message
|
||||
self.user_id = user_id
|
||||
self.memory = memory
|
||||
self.history_prompt_messages = prompt_messages
|
||||
self.history_prompt_messages = self.organize_agent_history(
|
||||
prompt_messages=prompt_messages or []
|
||||
)
|
||||
self.variables_pool = variables_pool
|
||||
self.db_variables_pool = db_variables
|
||||
self.model_instance = model_instance
|
||||
@@ -122,7 +132,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
|
||||
return app_orchestration_config
|
||||
|
||||
def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
|
||||
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
Handle tool response
|
||||
"""
|
||||
@@ -134,13 +144,13 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
result += f"result link: {response.message}. please tell user to check it."
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
result += f"image has been created and sent to user already, you should tell user to check it now."
|
||||
result += "image has been created and sent to user already, you should tell user to check it now."
|
||||
else:
|
||||
result += f"tool response: {response.message}."
|
||||
|
||||
return result
|
||||
|
||||
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
|
||||
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
|
||||
"""
|
||||
convert tool to prompt message tool
|
||||
"""
|
||||
@@ -325,7 +335,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
|
||||
return prompt_tool
|
||||
|
||||
def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
|
||||
def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
@@ -356,7 +366,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
|
||||
return result
|
||||
|
||||
def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
|
||||
def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]:
|
||||
"""
|
||||
Create message file
|
||||
|
||||
@@ -404,7 +414,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
return result
|
||||
|
||||
def create_agent_thought(self, message_id: str, message: str,
|
||||
tool_name: str, tool_input: str, messages_ids: List[str]
|
||||
tool_name: str, tool_input: str, messages_ids: list[str]
|
||||
) -> MessageAgentThought:
|
||||
"""
|
||||
Create agent thought
|
||||
@@ -449,7 +459,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
thought: str,
|
||||
observation: str,
|
||||
answer: str,
|
||||
messages_ids: List[str],
|
||||
messages_ids: list[str],
|
||||
llm_usage: LLMUsage = None) -> MessageAgentThought:
|
||||
"""
|
||||
Save agent thought
|
||||
@@ -504,19 +514,8 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
agent_thought.tool_labels_str = json.dumps(labels)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def get_history_prompt_messages(self) -> List[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages
|
||||
"""
|
||||
if self.history_prompt_messages is None:
|
||||
self.history_prompt_messages = db.session.query(PromptMessage).filter(
|
||||
PromptMessage.message_id == self.message.id,
|
||||
).order_by(PromptMessage.position.asc()).all()
|
||||
|
||||
return self.history_prompt_messages
|
||||
|
||||
def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]:
|
||||
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
|
||||
"""
|
||||
Transform tool message into agent thought
|
||||
"""
|
||||
@@ -589,4 +588,60 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
"""
|
||||
db_variables.updated_at = datetime.utcnow()
|
||||
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
|
||||
db.session.commit()
|
||||
db.session.commit()
|
||||
|
||||
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize agent history
|
||||
"""
|
||||
result = []
|
||||
# check if there is a system message in the beginning of the conversation
|
||||
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
result.append(prompt_messages[0])
|
||||
|
||||
messages: list[Message] = db.session.query(Message).filter(
|
||||
Message.conversation_id == self.message.conversation_id,
|
||||
).order_by(Message.created_at.asc()).all()
|
||||
|
||||
for message in messages:
|
||||
result.append(UserPromptMessage(content=message.query))
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
if agent_thoughts:
|
||||
for agent_thought in agent_thoughts:
|
||||
tools = agent_thought.tool
|
||||
if tools:
|
||||
tools = tools.split(';')
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
tool_call_response: list[ToolPromptMessage] = []
|
||||
tool_inputs = json.loads(agent_thought.tool_input)
|
||||
for tool in tools:
|
||||
# generate a uuid for tool call
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
tool_calls.append(AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool,
|
||||
arguments=json.dumps(tool_inputs.get(tool, {})),
|
||||
)
|
||||
))
|
||||
tool_call_response.append(ToolPromptMessage(
|
||||
content=agent_thought.observation,
|
||||
name=tool,
|
||||
tool_call_id=tool_call_id,
|
||||
))
|
||||
|
||||
result.extend([
|
||||
AssistantPromptMessage(
|
||||
content=agent_thought.thought,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
*tool_call_response
|
||||
])
|
||||
if not tools:
|
||||
result.append(AssistantPromptMessage(content=agent_thought.thought))
|
||||
else:
|
||||
if message.answer:
|
||||
result.append(AssistantPromptMessage(content=message.answer))
|
||||
|
||||
return result
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, Generator, List, Literal, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Literal, Union
|
||||
|
||||
from core.application_queue_manager import PublishFrom
|
||||
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
|
||||
@@ -11,6 +12,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@@ -29,7 +31,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
def run(self, conversation: Conversation,
|
||||
message: Message,
|
||||
query: str,
|
||||
inputs: Dict[str, str],
|
||||
inputs: dict[str, str],
|
||||
) -> Union[Generator, LLMResult]:
|
||||
"""
|
||||
Run Cot agent application
|
||||
@@ -37,7 +39,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
app_orchestration_config = self.app_orchestration_config
|
||||
self._repack_app_orchestration_config(app_orchestration_config)
|
||||
|
||||
agent_scratchpad: List[AgentScratchpadUnit] = []
|
||||
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||
self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
|
||||
|
||||
# check model mode
|
||||
if self.app_orchestration_config.model_config.mode == "completion":
|
||||
@@ -56,7 +59,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
prompt_messages = self.history_prompt_messages
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
prompt_messages_tools: List[PromptMessageTool] = []
|
||||
prompt_messages_tools: list[PromptMessageTool] = []
|
||||
tool_instances = {}
|
||||
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
|
||||
try:
|
||||
@@ -83,7 +86,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
}
|
||||
final_answer = ''
|
||||
|
||||
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
if not final_llm_usage_dict['usage']:
|
||||
final_llm_usage_dict['usage'] = usage
|
||||
else:
|
||||
@@ -130,61 +133,95 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
# recale llm max tokens
|
||||
self.recale_llm_max_tokens(self.model_config, prompt_messages)
|
||||
# invoke model
|
||||
llm_result: LLMResult = model_instance.invoke_llm(
|
||||
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_orchestration_config.model_config.parameters,
|
||||
tools=[],
|
||||
stop=app_orchestration_config.model_config.stop,
|
||||
stream=False,
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# check llm result
|
||||
if not llm_result:
|
||||
if not chunks:
|
||||
raise ValueError("failed to invoke llm")
|
||||
|
||||
# get scratchpad
|
||||
scratchpad = self._extract_response_scratchpad(llm_result.message.content)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
# get llm usage
|
||||
if llm_result.usage:
|
||||
increase_usage(llm_usage, llm_result.usage)
|
||||
|
||||
usage_dict = {}
|
||||
react_chunks = self._handle_stream_react(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response='',
|
||||
thought='',
|
||||
action_str='',
|
||||
observation='',
|
||||
action=None,
|
||||
)
|
||||
|
||||
# publish agent thought if it's first iteration
|
||||
if iteration_step == 1:
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, dict):
|
||||
scratchpad.agent_response += json.dumps(chunk)
|
||||
try:
|
||||
if scratchpad.action:
|
||||
raise Exception("")
|
||||
scratchpad.action_str = json.dumps(chunk)
|
||||
scratchpad.action = AgentScratchpadUnit.Action(
|
||||
action_name=chunk['action'],
|
||||
action_input=chunk['action_input']
|
||||
)
|
||||
except:
|
||||
scratchpad.thought += json.dumps(chunk)
|
||||
yield LLMResultChunk(
|
||||
model=self.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint='',
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=json.dumps(chunk)
|
||||
),
|
||||
usage=None
|
||||
)
|
||||
)
|
||||
else:
|
||||
scratchpad.agent_response += chunk
|
||||
scratchpad.thought += chunk
|
||||
yield LLMResultChunk(
|
||||
model=self.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint='',
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=chunk
|
||||
),
|
||||
usage=None
|
||||
)
|
||||
)
|
||||
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
# get llm usage
|
||||
if 'usage' in usage_dict:
|
||||
increase_usage(llm_usage, usage_dict['usage'])
|
||||
else:
|
||||
usage_dict['usage'] = LLMUsage.empty_usage()
|
||||
|
||||
self.save_agent_thought(agent_thought=agent_thought,
|
||||
tool_name=scratchpad.action.action_name if scratchpad.action else '',
|
||||
tool_input=scratchpad.action.action_input if scratchpad.action else '',
|
||||
thought=scratchpad.thought,
|
||||
observation='',
|
||||
answer=llm_result.message.content,
|
||||
answer=scratchpad.agent_response,
|
||||
messages_ids=[],
|
||||
llm_usage=llm_result.usage)
|
||||
llm_usage=usage_dict['usage'])
|
||||
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
# publish agent thought if it's not empty and there is a action
|
||||
if scratchpad.thought and scratchpad.action:
|
||||
# check if final answer
|
||||
if not scratchpad.action.action_name.lower() == "final answer":
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=scratchpad.thought
|
||||
),
|
||||
usage=llm_result.usage,
|
||||
),
|
||||
system_fingerprint=''
|
||||
)
|
||||
|
||||
if not scratchpad.action:
|
||||
# failed to extract action, return final answer directly
|
||||
final_answer = scratchpad.agent_response or ''
|
||||
@@ -238,7 +275,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
|
||||
message_file_ids = [message_file.id for message_file, _ in message_files]
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
error_response = f"Please check your tool provider credentials"
|
||||
error_response = "Please check your tool provider credentials"
|
||||
except (
|
||||
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
) as e:
|
||||
@@ -259,7 +296,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
|
||||
# save scratchpad
|
||||
scratchpad.observation = observation
|
||||
scratchpad.agent_response = llm_result.message.content
|
||||
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
@@ -268,7 +304,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
tool_input=tool_call_args,
|
||||
thought=None,
|
||||
observation=observation,
|
||||
answer=llm_result.message.content,
|
||||
answer=scratchpad.agent_response,
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
@@ -315,6 +351,97 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
system_fingerprint=''
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
|
||||
-> Generator[Union[str, dict], None, None]:
|
||||
def parse_json(json_str):
|
||||
try:
|
||||
return json.loads(json_str.strip())
|
||||
except:
|
||||
return json_str
|
||||
|
||||
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
|
||||
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
|
||||
if not code_blocks:
|
||||
return
|
||||
for block in code_blocks:
|
||||
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
|
||||
yield parse_json(json_text)
|
||||
|
||||
code_block_cache = ''
|
||||
code_block_delimiter_count = 0
|
||||
in_code_block = False
|
||||
json_cache = ''
|
||||
json_quote_count = 0
|
||||
in_json = False
|
||||
got_json = False
|
||||
|
||||
for response in llm_response:
|
||||
response = response.delta.message.content
|
||||
if not isinstance(response, str):
|
||||
continue
|
||||
|
||||
# stream
|
||||
index = 0
|
||||
while index < len(response):
|
||||
steps = 1
|
||||
delta = response[index:index+steps]
|
||||
if delta == '`':
|
||||
code_block_cache += delta
|
||||
code_block_delimiter_count += 1
|
||||
else:
|
||||
if not in_code_block:
|
||||
if code_block_delimiter_count > 0:
|
||||
yield code_block_cache
|
||||
code_block_cache = ''
|
||||
else:
|
||||
code_block_cache += delta
|
||||
code_block_delimiter_count = 0
|
||||
|
||||
if code_block_delimiter_count == 3:
|
||||
if in_code_block:
|
||||
yield from extra_json_from_code_block(code_block_cache)
|
||||
code_block_cache = ''
|
||||
|
||||
in_code_block = not in_code_block
|
||||
code_block_delimiter_count = 0
|
||||
|
||||
if not in_code_block:
|
||||
# handle single json
|
||||
if delta == '{':
|
||||
json_quote_count += 1
|
||||
in_json = True
|
||||
json_cache += delta
|
||||
elif delta == '}':
|
||||
json_cache += delta
|
||||
if json_quote_count > 0:
|
||||
json_quote_count -= 1
|
||||
if json_quote_count == 0:
|
||||
in_json = False
|
||||
got_json = True
|
||||
index += steps
|
||||
continue
|
||||
else:
|
||||
if in_json:
|
||||
json_cache += delta
|
||||
|
||||
if got_json:
|
||||
got_json = False
|
||||
yield parse_json(json_cache)
|
||||
json_cache = ''
|
||||
json_quote_count = 0
|
||||
in_json = False
|
||||
|
||||
if not in_code_block and not in_json:
|
||||
yield delta.replace('`', '')
|
||||
|
||||
index += steps
|
||||
|
||||
if code_block_cache:
|
||||
yield code_block_cache
|
||||
|
||||
if json_cache:
|
||||
yield parse_json(json_cache)
|
||||
|
||||
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
|
||||
"""
|
||||
fill in inputs from external data tools
|
||||
@@ -326,122 +453,40 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
continue
|
||||
|
||||
return instruction
|
||||
|
||||
def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
|
||||
|
||||
def _init_agent_scratchpad(self,
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
messages: list[PromptMessage]
|
||||
) -> list[AgentScratchpadUnit]:
|
||||
"""
|
||||
extract response from llm response
|
||||
init agent scratchpad
|
||||
"""
|
||||
def extra_quotes() -> AgentScratchpadUnit:
|
||||
agent_response = content
|
||||
# try to extract all quotes
|
||||
pattern = re.compile(r'```(.*?)```', re.DOTALL)
|
||||
quotes = pattern.findall(content)
|
||||
|
||||
# try to extract action from end to start
|
||||
for i in range(len(quotes) - 1, 0, -1):
|
||||
"""
|
||||
1. use json load to parse action
|
||||
2. use plain text `Action: xxx` to parse action
|
||||
"""
|
||||
try:
|
||||
action = json.loads(quotes[i].replace('```', ''))
|
||||
action_name = action.get("action")
|
||||
action_input = action.get("action_input")
|
||||
agent_thought = agent_response.replace(quotes[i], '')
|
||||
|
||||
if action_name and action_input:
|
||||
return AgentScratchpadUnit(
|
||||
agent_response=content,
|
||||
thought=agent_thought,
|
||||
action_str=quotes[i],
|
||||
action=AgentScratchpadUnit.Action(
|
||||
action_name=action_name,
|
||||
action_input=action_input,
|
||||
)
|
||||
current_scratchpad: AgentScratchpadUnit = None
|
||||
for message in messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content,
|
||||
action_str='',
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
if message.tool_calls:
|
||||
try:
|
||||
current_scratchpad.action = AgentScratchpadUnit.Action(
|
||||
action_name=message.tool_calls[0].function.name,
|
||||
action_input=json.loads(message.tool_calls[0].function.arguments)
|
||||
)
|
||||
except:
|
||||
# try to parse action from plain text
|
||||
action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
|
||||
action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
|
||||
# delete action from agent response
|
||||
agent_thought = agent_response.replace(quotes[i], '')
|
||||
# remove extra quotes
|
||||
agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
|
||||
# remove Action: xxx from agent thought
|
||||
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
|
||||
|
||||
if action_name and action_input:
|
||||
return AgentScratchpadUnit(
|
||||
agent_response=content,
|
||||
thought=agent_thought,
|
||||
action_str=quotes[i],
|
||||
action=AgentScratchpadUnit.Action(
|
||||
action_name=action_name[0],
|
||||
action_input=action_input[0],
|
||||
)
|
||||
)
|
||||
|
||||
def extra_json():
|
||||
agent_response = content
|
||||
# try to extract all json
|
||||
structures, pair_match_stack = [], []
|
||||
started_at, end_at = 0, 0
|
||||
for i in range(len(content)):
|
||||
if content[i] == '{':
|
||||
pair_match_stack.append(i)
|
||||
if len(pair_match_stack) == 1:
|
||||
started_at = i
|
||||
elif content[i] == '}':
|
||||
begin = pair_match_stack.pop()
|
||||
if not pair_match_stack:
|
||||
end_at = i + 1
|
||||
structures.append((content[begin:i+1], (started_at, end_at)))
|
||||
|
||||
# handle the last character
|
||||
if pair_match_stack:
|
||||
end_at = len(content)
|
||||
structures.append((content[pair_match_stack[0]:], (started_at, end_at)))
|
||||
|
||||
for i in range(len(structures), 0, -1):
|
||||
try:
|
||||
json_content, (started_at, end_at) = structures[i - 1]
|
||||
action = json.loads(json_content)
|
||||
action_name = action.get("action")
|
||||
action_input = action.get("action_input")
|
||||
# delete json content from agent response
|
||||
agent_thought = agent_response[:started_at] + agent_response[end_at:]
|
||||
# remove extra quotes like ```(json)*\n\n```
|
||||
agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
|
||||
# remove Action: xxx from agent thought
|
||||
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
|
||||
|
||||
if action_name and action_input is not None:
|
||||
return AgentScratchpadUnit(
|
||||
agent_response=content,
|
||||
thought=agent_thought,
|
||||
action_str=json_content,
|
||||
action=AgentScratchpadUnit.Action(
|
||||
action_name=action_name,
|
||||
action_input=action_input,
|
||||
)
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
agent_scratchpad = extra_quotes()
|
||||
if agent_scratchpad:
|
||||
return agent_scratchpad
|
||||
agent_scratchpad = extra_json()
|
||||
if agent_scratchpad:
|
||||
return agent_scratchpad
|
||||
|
||||
return AgentScratchpadUnit(
|
||||
agent_response=content,
|
||||
thought=content,
|
||||
action_str='',
|
||||
action=None
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
agent_scratchpad.append(current_scratchpad)
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if current_scratchpad:
|
||||
current_scratchpad.observation = message.content
|
||||
|
||||
return agent_scratchpad
|
||||
|
||||
def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
|
||||
agent_prompt_message: AgentPromptEntity,
|
||||
):
|
||||
@@ -473,7 +518,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
next_iteration = agent_prompt_message.next_iteration
|
||||
|
||||
if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
|
||||
raise ValueError(f"first_prompt or next_iteration is required in CoT agent mode")
|
||||
raise ValueError("first_prompt or next_iteration is required in CoT agent mode")
|
||||
|
||||
# check instruction, tools, and tool_names slots
|
||||
if not first_prompt.find("{{instruction}}") >= 0:
|
||||
@@ -493,7 +538,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
if not next_iteration.find("{{observation}}") >= 0:
|
||||
raise ValueError("{{observation}} is required in next_iteration")
|
||||
|
||||
def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
|
||||
def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
|
||||
"""
|
||||
convert agent scratchpad list to str
|
||||
"""
|
||||
@@ -506,13 +551,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
return result
|
||||
|
||||
def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: List[PromptMessageTool],
|
||||
agent_scratchpad: List[AgentScratchpadUnit],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
agent_prompt_message: AgentPromptEntity,
|
||||
instruction: str,
|
||||
input: str,
|
||||
) -> List[PromptMessage]:
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
organize chain of thought prompt messages, a standard prompt message is like:
|
||||
Respond to the human as helpfully and accurately as possible.
|
||||
@@ -555,15 +600,22 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
# organize prompt messages
|
||||
if mode == "chat":
|
||||
# override system message
|
||||
overrided = False
|
||||
overridden = False
|
||||
prompt_messages = prompt_messages.copy()
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, SystemPromptMessage):
|
||||
prompt_message.content = system_message
|
||||
overrided = True
|
||||
overridden = True
|
||||
break
|
||||
|
||||
# convert tool prompt messages to user prompt messages
|
||||
for idx, prompt_message in enumerate(prompt_messages):
|
||||
if isinstance(prompt_message, ToolPromptMessage):
|
||||
prompt_messages[idx] = UserPromptMessage(
|
||||
content=prompt_message.content
|
||||
)
|
||||
|
||||
if not overrided:
|
||||
if not overridden:
|
||||
prompt_messages.insert(0, SystemPromptMessage(
|
||||
content=system_message,
|
||||
))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Tuple, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from core.application_queue_manager import PublishFrom
|
||||
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
|
||||
@@ -44,7 +45,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
)
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
prompt_messages_tools: List[PromptMessageTool] = []
|
||||
prompt_messages_tools: list[PromptMessageTool] = []
|
||||
tool_instances = {}
|
||||
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
|
||||
try:
|
||||
@@ -70,13 +71,13 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
agent_thoughts: List[MessageAgentThought] = []
|
||||
agent_thoughts: list[MessageAgentThought] = []
|
||||
llm_usage = {
|
||||
'usage': None
|
||||
}
|
||||
final_answer = ''
|
||||
|
||||
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
if not final_llm_usage_dict['usage']:
|
||||
final_llm_usage_dict['usage'] = usage
|
||||
else:
|
||||
@@ -117,7 +118,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
|
||||
# save full response
|
||||
response = ''
|
||||
@@ -277,7 +278,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
message_file_ids.append(message_file.id)
|
||||
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
error_response = f"Please check your tool provider credentials"
|
||||
error_response = "Please check your tool provider credentials"
|
||||
except (
|
||||
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
) as e:
|
||||
@@ -364,7 +365,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
return True
|
||||
return False
|
||||
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
"""
|
||||
Extract tool calls from llm result chunk
|
||||
|
||||
@@ -381,7 +382,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
|
||||
return tool_calls
|
||||
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
"""
|
||||
Extract blocking tool calls from llm result
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
@@ -96,7 +96,7 @@ class DatasetRetrievalFeature:
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler) \
|
||||
-> Optional[List[BaseTool]]:
|
||||
-> Optional[list[BaseTool]]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
:param tenant_id: tenant id
|
||||
|
||||
@@ -2,7 +2,7 @@ import concurrent
|
||||
import json
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
@@ -62,7 +62,7 @@ class ExternalDataFetchFeature:
|
||||
app_id: str,
|
||||
external_data_tool: ExternalDataVariableEntity,
|
||||
inputs: dict,
|
||||
query: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
query: str) -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Query external data tool.
|
||||
:param flask_app: flask app
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
from core.entities.application_entities import AppOrchestrationConfigEntity
|
||||
from core.moderation.base import ModerationAction, ModerationException
|
||||
@@ -13,7 +12,7 @@ class ModerationFeature:
|
||||
tenant_id: str,
|
||||
app_orchestration_config_entity: AppOrchestrationConfigEntity,
|
||||
inputs: dict,
|
||||
query: str) -> Tuple[bool, dict, str]:
|
||||
query: str) -> tuple[bool, dict, str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
:param app_id: app id
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
@@ -15,8 +15,8 @@ class MessageFileParser:
|
||||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
|
||||
def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
|
||||
user: Union[Account, EndUser]) -> List[FileObj]:
|
||||
def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig,
|
||||
user: Union[Account, EndUser]) -> list[FileObj]:
|
||||
"""
|
||||
validate and transform files arg
|
||||
|
||||
@@ -96,7 +96,7 @@ class MessageFileParser:
|
||||
# return all file objs
|
||||
return new_files
|
||||
|
||||
def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
|
||||
def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]:
|
||||
"""
|
||||
transform message files
|
||||
|
||||
@@ -110,8 +110,8 @@ class MessageFileParser:
|
||||
# return all file objs
|
||||
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
|
||||
|
||||
def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
|
||||
file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
|
||||
def _to_file_objs(self, files: list[Union[dict, MessageFile]],
|
||||
file_upload_config: dict) -> dict[FileType, list[FileObj]]:
|
||||
"""
|
||||
transform files to file objs
|
||||
|
||||
@@ -119,7 +119,7 @@ class MessageFileParser:
|
||||
:param file_upload_config:
|
||||
:return:
|
||||
"""
|
||||
type_file_objs: Dict[FileType, List[FileObj]] = {
|
||||
type_file_objs: dict[FileType, list[FileObj]] = {
|
||||
# Currently only support image
|
||||
FileType.IMAGE: []
|
||||
}
|
||||
|
||||
@@ -104,37 +104,17 @@ class HostingConfiguration:
|
||||
|
||||
if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
|
||||
hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
|
||||
trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(
|
||||
quota_limit=hosted_quota_limit,
|
||||
restrict_models=[
|
||||
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
|
||||
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
|
||||
]
|
||||
restrict_models=trial_models
|
||||
)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
|
||||
paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(
|
||||
restrict_models=[
|
||||
RestrictModel(model="gpt-4", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
|
||||
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
|
||||
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
|
||||
]
|
||||
restrict_models=paid_models
|
||||
)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
@@ -258,3 +238,11 @@ class HostingConfiguration:
|
||||
return HostedModerationConfig(
|
||||
enabled=False
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
|
||||
models_str = app_config.get(env_var)
|
||||
models_list = models_str.split(",") if models_str else []
|
||||
return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if
|
||||
model_name.strip()]
|
||||
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
from flask import current_app
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class IndexBuilder:
|
||||
@classmethod
|
||||
def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
|
||||
if indexing_technique == "high_quality":
|
||||
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
|
||||
return None
|
||||
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
return VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
elif indexing_technique == "economy":
|
||||
return KeywordTableIndex(
|
||||
dataset=dataset,
|
||||
config=KeywordTableConfig(
|
||||
max_keywords_per_chunk=10
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError('Unknown indexing technique')
|
||||
|
||||
@classmethod
|
||||
def get_default_high_quality_index(cls, dataset: Dataset):
|
||||
embeddings = OpenAIEmbeddings(openai_api_key=' ')
|
||||
return VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
@@ -1,305 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Any, List, cast
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.vectorstores import VectorStore
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
|
||||
class BaseVectorIndex(BaseIndex):
|
||||
|
||||
def __init__(self, dataset: Dataset, embeddings: Embeddings):
|
||||
super().__init__(dataset)
|
||||
self._embeddings = embeddings
|
||||
self._vector_store = None
|
||||
|
||||
def get_type(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_index_name(self, dataset: Dataset) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_index_struct(self) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_vector_store(self) -> VectorStore:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_vector_store_class(self) -> type:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def search_by_full_text_index(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
|
||||
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
|
||||
|
||||
if search_type == 'similarity_score_threshold':
|
||||
score_threshold = search_kwargs.get("score_threshold")
|
||||
if (score_threshold is None) or (not isinstance(score_threshold, float)):
|
||||
search_kwargs['score_threshold'] = .0
|
||||
|
||||
docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
|
||||
query, **search_kwargs
|
||||
)
|
||||
|
||||
docs = []
|
||||
for doc, similarity in docs_with_similarity:
|
||||
doc.metadata['score'] = similarity
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
# similarity k
|
||||
# mmr k, fetch_k, lambda_mult
|
||||
# similarity_score_threshold k
|
||||
return vector_store.as_retriever(
|
||||
search_type=search_type,
|
||||
search_kwargs=search_kwargs
|
||||
).get_relevant_documents(query)
|
||||
|
||||
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
return vector_store.as_retriever(**kwargs)
|
||||
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
if kwargs.get('duplicate_check', False):
|
||||
texts = self._filter_duplicate_texts(texts)
|
||||
|
||||
uuids = self._get_uuids(texts)
|
||||
vector_store.add_documents(texts, uuids=uuids)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
return vector_store.text_exists(id)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
return
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
for node_id in ids:
|
||||
vector_store.del_text(node_id)
|
||||
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
if self.dataset.collection_binding_id:
|
||||
vector_store.delete_by_group_id(group_id)
|
||||
else:
|
||||
vector_store.delete()
|
||||
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.delete()
|
||||
|
||||
def _is_origin(self):
|
||||
return False
|
||||
|
||||
def recreate_dataset(self, dataset: Dataset):
|
||||
logging.info(f"Recreating dataset {dataset.id}")
|
||||
|
||||
try:
|
||||
self.delete()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == 'completed',
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
|
||||
origin_index_struct = self.dataset.index_struct[:]
|
||||
self.dataset.index_struct = None
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.create(documents)
|
||||
except Exception as e:
|
||||
self.dataset.index_struct = origin_index_struct
|
||||
raise e
|
||||
|
||||
dataset.index_struct = json.dumps(self.to_index_struct())
|
||||
|
||||
db.session.commit()
|
||||
|
||||
self.dataset = dataset
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def create_qdrant_dataset(self, dataset: Dataset):
|
||||
logging.info(f"create_qdrant_dataset {dataset.id}")
|
||||
|
||||
try:
|
||||
self.delete()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == 'completed',
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.create(documents)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def update_qdrant_dataset(self, dataset: Dataset):
|
||||
logging.info(f"update_qdrant_dataset {dataset.id}")
|
||||
|
||||
segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).first()
|
||||
|
||||
if segment:
|
||||
try:
|
||||
exist = self.text_exists(segment.index_node_id)
|
||||
if exist:
|
||||
index_struct = {
|
||||
"type": 'qdrant',
|
||||
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
|
||||
logging.info(f"restore dataset in_one,_dataset {dataset.id}")
|
||||
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == 'completed',
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
for dataset_document in dataset_documents:
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.add_texts(documents)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
|
||||
def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
|
||||
logging.info(f"delete original collection: {dataset.id}")
|
||||
|
||||
self.delete()
|
||||
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
|
||||
logging.info(f"Dataset {dataset.id} recreate successfully.")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user