Mirror of https://github.com/langgenius/dify.git (synced 2026-01-21 06:24:01 +00:00)
Compare commits
40 Commits
| SHA1 |
|---|
| 534802b761 |
| 5c258e212c |
| 6a6133c102 |
| 3c1825187a |
| 8523b34be7 |
| 65cfd4360a |
| bbf5f42c87 |
| 3631e53ff0 |
| f322d9bddb |
| 05ce7b9d5e |
| 72ddedfc5c |
| 36686d7425 |
| 34387ec0f1 |
| 83a6b0c626 |
| 76da66fb7e |
| 607f9eda35 |
| f25cec265d |
| 8e66b96221 |
| b5c1bb346c |
| e94b323e6c |
| bc65ee10c0 |
| 2001483659 |
| 444aba55dd |
| 3f640b1037 |
| b07084711c |
| fa8ab2134f |
| 1a677da792 |
| b6d61a818e |
| 8495ffaa45 |
| dbd1d79770 |
| 1910178199 |
| 839a6a2c8a |
| a769edbc89 |
| 57ffecd0e5 |
| 801d135390 |
| 0428f44113 |
| 7beff3fd5a |
| 88a095e40e |
| dd961985f0 |
| d44b05a9e5 |
.github/workflows/style.yml (vendored, 5 changes)
@@ -41,6 +41,8 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0

      - name: Setup NodeJS
        uses: actions/setup-node@v4
@@ -60,11 +62,10 @@ jobs:
        yarn run lint

      - name: Super-linter
-        uses: super-linter/super-linter/slim@v5
+        uses: super-linter/super-linter/slim@v6
        env:
          BASH_SEVERITY: warning
          DEFAULT_BRANCH: main
          ERROR_ON_MISSING_EXEC_BIT: true
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          IGNORE_GENERATED_FILES: true
          IGNORE_GITIGNORED_FILES: true
LICENSE (22 changes)
@@ -1,24 +1,26 @@
-# Dify Open Source License
+# Open Source License

-The Dify project is licensed under the Apache License 2.0, with the following additional conditions:
+Dify is licensed under the Apache License 2.0, with the following additional conditions:

-1. Dify is permitted to be used for commercialization, such as using Dify as a "backend-as-a-service" for your other applications, or delivering it to enterprises as an application development platform. However, when the following conditions are met, you must contact the producer to obtain a commercial license:
+1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:

-a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify.AI source code to operate a multi-tenant SaaS service that is similar to the Dify.AI service edition.
-b. LOGO and copyright information: In the process of using Dify, you may not remove or modify the LOGO or copyright information in the Dify console.
+a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
+    - Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
+
+b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.

Please contact business@dify.ai by email to inquire about licensing matters.

-2. As a contributor, you should agree that your contributed code:
+2. As a contributor, you should agree that:

-a. The producer can adjust the open-source agreement to be more strict or relaxed.
-b. Can be used for commercial purposes, such as Dify's cloud business.
+a. The producer can adjust the open-source agreement to be more strict or relaxed as deemed necessary.
+b. Your contributed code may be used for commercial purposes, including but not limited to its cloud business operations.

-Apart from this, all other rights and restrictions follow the Apache License 2.0. If you need more detailed information, you can refer to the full version of Apache License 2.0.
+Apart from the specific conditions mentioned above, all other rights and restrictions follow the Apache License 2.0. Detailed information about the Apache License 2.0 can be found at http://www.apache.org/licenses/LICENSE-2.0.

The interactive design of this product is protected by appearance patent.

-© 2023 LangGenius, Inc.
+© 2024 LangGenius, Inc.


----------
@@ -82,7 +82,7 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
MULTIMODAL_SEND_IMAGE_FORMAT=base64

# Mail configuration, support: resend, smtp
-MAIL_TYPE=resend
+MAIL_TYPE=
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
@@ -131,4 +131,4 @@ UNSTRUCTURED_API_URL=
SSRF_PROXY_HTTP_URL=
SSRF_PROXY_HTTPS_URL=

-BATCH_UPLOAD_LIMIT=10
+BATCH_UPLOAD_LIMIT=10
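With MAIL_TYPE now shipping blank, mail features stay off until a deployment opts in. A minimal sketch of how a startup script might validate the setting (the variable name and supported values come from the env comment above; the check itself is hypothetical, not part of the diff):

```python
import os

# Hypothetical startup check: MAIL_TYPE defaults to empty after this change,
# so a deployment must explicitly pick one of the supported backends.
SUPPORTED_MAIL_TYPES = {'resend', 'smtp'}  # per the comment in the env file

mail_type = os.environ.get('MAIL_TYPE', '')
if mail_type and mail_type not in SUPPORTED_MAIL_TYPES:
    raise ValueError(f'Unsupported MAIL_TYPE: {mail_type!r}')
```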
api/commands.py (146 changes)
@@ -15,7 +15,7 @@ from libs.rsa import generate_key_pair
from models.account import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
-from models.model import Account
+from models.model import Account, App, AppAnnotationSetting, MessageAnnotation
from models.provider import Provider, ProviderModel


@@ -125,12 +125,121 @@ def reset_encrypt_key_pair():


@click.command('vdb-migrate', help='migrate vector db.')
-def vdb_migrate():
+@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
+def vdb_migrate(scope: str):
    if scope in ['knowledge', 'all']:
        migrate_knowledge_vector_database()
    if scope in ['annotation', 'all']:
        migrate_annotation_vector_database()


def migrate_annotation_vector_database():
    """
    Migrate annotation datas to target vector database .
    """
    click.echo(click.style('Start migrate annotation data.', fg='green'))
    create_count = 0
    skipped_count = 0
    total_count = 0
    page = 1
    while True:
        try:
            # get apps info
            apps = db.session.query(App).filter(
                App.status == 'normal'
            ).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
        except NotFound:
            break

        page += 1
        for app in apps:
            total_count = total_count + 1
            click.echo(f'Processing the {total_count} app {app.id}. '
                       + f'{create_count} created, {skipped_count} skipped.')
            try:
                click.echo('Create app annotation index: {}'.format(app.id))
                app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
                    AppAnnotationSetting.app_id == app.id
                ).first()

                if not app_annotation_setting:
                    skipped_count = skipped_count + 1
                    click.echo('App annotation setting is disabled: {}'.format(app.id))
                    continue
                # get dataset_collection_binding info
                dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
                    DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
                ).first()
                if not dataset_collection_binding:
                    click.echo('App annotation collection binding is not exist: {}'.format(app.id))
                    continue
                annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
                dataset = Dataset(
                    id=app.id,
                    tenant_id=app.tenant_id,
                    indexing_technique='high_quality',
                    embedding_model_provider=dataset_collection_binding.provider_name,
                    embedding_model=dataset_collection_binding.model_name,
                    collection_binding_id=dataset_collection_binding.id
                )
                documents = []
                if annotations:
                    for annotation in annotations:
                        document = Document(
                            page_content=annotation.question,
                            metadata={
                                "annotation_id": annotation.id,
                                "app_id": app.id,
                                "doc_id": annotation.id
                            }
                        )
                        documents.append(document)

                vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
                click.echo(f"Start to migrate annotation, app_id: {app.id}.")

                try:
                    vector.delete()
                    click.echo(
                        click.style(f'Successfully delete vector index for app: {app.id}.',
                                    fg='green'))
                except Exception as e:
                    click.echo(
                        click.style(f'Failed to delete vector index for app {app.id}.',
                                    fg='red'))
                    raise e
                if documents:
                    try:
                        click.echo(click.style(
                            f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
                            fg='green'))
                        vector.create(documents)
                        click.echo(
                            click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
                    except Exception as e:
                        click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
                        raise e
                click.echo(f'Successfully migrated app annotation {app.id}.')
                create_count += 1
            except Exception as e:
                click.echo(
                    click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
                                fg='red'))
                continue

    click.echo(
        click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
                    fg='green'))


def migrate_knowledge_vector_database():
    """
    Migrate vector database datas to target vector database .
    """
    click.echo(click.style('Start migrate vector db.', fg='green'))
    create_count = 0
    skipped_count = 0
    total_count = 0
    config = current_app.config
    vector_type = config.get('VECTOR_STORE')
    page = 1
@@ -143,14 +252,19 @@ def vdb_migrate():

        page += 1
        for dataset in datasets:
            total_count = total_count + 1
+            click.echo(f'Processing the {total_count} dataset {dataset.id}. '
+                       + f'{create_count} created, ${skipped_count} skipped.')
            try:
                click.echo('Create dataset vdb index: {}'.format(dataset.id))
                if dataset.index_struct_dict:
                    if dataset.index_struct_dict['type'] == vector_type:
+                        skipped_count = skipped_count + 1
                        continue
                collection_name = ''
                if vector_type == "weaviate":
                    dataset_id = dataset.id
-                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                    index_struct_dict = {
                        "type": 'weaviate',
                        "vector_store": {"class_prefix": collection_name}
@@ -167,7 +281,7 @@ def vdb_migrate():
                        raise ValueError('Dataset Collection Bindings is not exist!')
                else:
                    dataset_id = dataset.id
-                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                    index_struct_dict = {
                        "type": 'qdrant',
                        "vector_store": {"class_prefix": collection_name}
@@ -176,7 +290,7 @@ def vdb_migrate():

                elif vector_type == "milvus":
                    dataset_id = dataset.id
-                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                    index_struct_dict = {
                        "type": 'milvus',
                        "vector_store": {"class_prefix": collection_name}
@@ -186,11 +300,17 @@ def vdb_migrate():
                    raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

                vector = Vector(dataset)
-                click.echo(f"vdb_migrate {dataset.id}")
+                click.echo(f"Start to migrate dataset {dataset.id}.")

                try:
                    vector.delete()
                    click.echo(
                        click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
                                    fg='green'))
                except Exception as e:
                    click.echo(
                        click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
                                    fg='red'))
                    raise e

                dataset_documents = db.session.query(DatasetDocument).filter(
@@ -201,6 +321,7 @@ def vdb_migrate():
                ).all()

                documents = []
+                segments_count = 0
                for dataset_document in dataset_documents:
                    segments = db.session.query(DocumentSegment).filter(
                        DocumentSegment.document_id == dataset_document.id,
@@ -220,15 +341,22 @@ def vdb_migrate():
                    )

                    documents.append(document)
+                    segments_count = segments_count + 1

                if documents:
                    try:
+                        click.echo(click.style(
+                            f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
+                            fg='green'))
                        vector.create(documents)
+                        click.echo(
+                            click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
                    except Exception as e:
+                        click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
                        raise e
-                    click.echo(f"Dataset {dataset.id} create successfully.")
                db.session.add(dataset)
                db.session.commit()
+                click.echo(f'Successfully migrated dataset {dataset.id}.')
                create_count += 1
            except Exception as e:
                db.session.rollback()
@@ -237,7 +365,9 @@ def vdb_migrate():
                                fg='red'))
                continue

-    click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
+    click.echo(
+        click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
+                    fg='green'))


def register_commands(app):
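Assuming the command above is registered on the Flask CLI via register_commands, the new --scope option can be exercised in isolation with click's test runner. This is a minimal sketch, not part of the diff; the import path of vdb_migrate and a configured app context are assumptions:

```python
from click.testing import CliRunner

from commands import vdb_migrate  # assumed import path; requires app/db setup

runner = CliRunner()
# '--scope annotation' runs only migrate_annotation_vector_database();
# the default '--scope all' would run the knowledge migration as well.
result = runner.invoke(vdb_migrate, ['--scope', 'annotation'])
print(result.output)
```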
@@ -90,7 +90,7 @@ class Config:
        # ------------------------
        # General Configurations.
        # ------------------------
-        self.CURRENT_VERSION = "0.5.7"
+        self.CURRENT_VERSION = "0.5.8"
        self.COMMIT_SHA = get_env('COMMIT_SHA')
        self.EDITION = "SELF_HOSTED"
        self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -13,30 +13,14 @@ model_templates = {
            'status': 'normal'
        },
        'model_config': {
-            'provider': 'openai',
-            'model_id': 'gpt-3.5-turbo-instruct',
-            'configs': {
-                'prompt_template': '',
-                'prompt_variables': [],
-                'completion_params': {
-                    'max_token': 512,
-                    'temperature': 1,
-                    'top_p': 1,
-                    'presence_penalty': 0,
-                    'frequency_penalty': 0,
-                }
-            },
+            'provider': '',
+            'model_id': '',
+            'configs': {},
            'model': json.dumps({
                "provider": "openai",
                "name": "gpt-3.5-turbo-instruct",
                "mode": "completion",
-                "completion_params": {
-                    "max_tokens": 512,
-                    "temperature": 1,
-                    "top_p": 1,
-                    "presence_penalty": 0,
-                    "frequency_penalty": 0
-                }
+                "completion_params": {}
            }),
            'user_input_form': json.dumps([
                {
@@ -64,30 +48,14 @@ model_templates = {
            'status': 'normal'
        },
        'model_config': {
-            'provider': 'openai',
-            'model_id': 'gpt-3.5-turbo',
-            'configs': {
-                'prompt_template': '',
-                'prompt_variables': [],
-                'completion_params': {
-                    'max_token': 512,
-                    'temperature': 1,
-                    'top_p': 1,
-                    'presence_penalty': 0,
-                    'frequency_penalty': 0,
-                }
-            },
+            'provider': '',
+            'model_id': '',
+            'configs': {},
            'model': json.dumps({
                "provider": "openai",
                "name": "gpt-3.5-turbo",
                "mode": "chat",
-                "completion_params": {
-                    "max_tokens": 512,
-                    "temperature": 1,
-                    "top_p": 1,
-                    "presence_penalty": 0,
-                    "frequency_penalty": 0
-                }
+                "completion_params": {}
            })
        }
    },
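The slimmed-down templates leave model parameters to runtime defaults. As a sketch of the resulting shape (values taken from the hunk above; the parsing is only illustrative):

```python
import json

template_model = json.loads(
    '{"provider": "openai", "name": "gpt-3.5-turbo", '
    '"mode": "chat", "completion_params": {}}'
)
# completion_params is now empty, so provider-side defaults apply until the
# user configures parameters; the app-creation hunk below overwrites provider
# and name with the workspace's default reasoning model when one exists.
assert template_model["completion_params"] == {}
```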
@@ -129,7 +129,7 @@ class AppListApi(Resource):
                        "No Default System Reasoning Model available. Please configure "
                        "in the Settings -> Model Provider.")
                else:
-                    model_config_dict["model"]["provider"] = default_model_entity.provider
+                    model_config_dict["model"]["provider"] = default_model_entity.provider.provider
                    model_config_dict["model"]["name"] = default_model_entity.model

                model_configuration = AppModelConfigService.validate_configuration(
@@ -88,7 +88,7 @@ class ChatMessageTextApi(Resource):
        response = AudioService.transcript_tts(
            tenant_id=app_model.tenant_id,
            text=request.form['text'],
-            voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
+            voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
            streaming=False
        )
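The same caller-supplied-voice fallback recurs in three more endpoints below. Note that `request.form['voice']` assumes the field is always present in the form (a missing key would abort the request in Werkzeug). A defensive variant, shown here as my own sketch rather than the commit's code, falls through on missing or empty values:

```python
# Hypothetical defensive variant of the voice fallback used in these endpoints:
# a missing or empty form field falls through to the app's configured default.
def resolve_voice(form: dict, default_voice: str | None) -> str | None:
    return form.get('voice') or default_voice

assert resolve_voice({}, 'alloy') == 'alloy'
assert resolve_voice({'voice': 'echo'}, 'alloy') == 'echo'
```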
@@ -11,7 +11,7 @@ from controllers.console.datasets.error import (
    UnsupportedFileTypeError,
)
from controllers.console.setup import setup_required
-from controllers.console.wraps import account_initialization_required
+from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from services.file_service import ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS, FileService
@@ -39,6 +39,7 @@ class FileApi(Resource):
    @login_required
    @account_initialization_required
    @marshal_with(file_fields)
+    @cloud_edition_billing_resource_check(resource='documents')
    def post(self):

        # get file from request
@@ -85,7 +85,7 @@ class ChatTextApi(InstalledAppResource):
        response = AudioService.transcript_tts(
            tenant_id=app_model.tenant_id,
            text=request.form['text'],
-            voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
+            voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
            streaming=False
        )
        return {'data': response.data.decode('latin1')}
@@ -259,6 +259,7 @@ class ToolApiProviderPreviousTestApi(Resource):
        parser = reqparse.RequestParser()

        parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json')
        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
        parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
        parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
@@ -268,6 +269,7 @@ class ToolApiProviderPreviousTestApi(Resource):

        return ToolManageService.test_api_tool_preview(
            current_user.current_tenant_id,
+            args['provider_name'] if args['provider_name'] else '',
            args['tool_name'],
            args['credentials'],
            args['parameters'],
@@ -56,6 +56,7 @@ def cloud_edition_billing_resource_check(resource: str,
        members = features.members
        apps = features.apps
        vector_space = features.vector_space
+        documents_upload_quota = features.documents_upload_quota
        annotation_quota_limit = features.annotation_quota_limit

        if resource == 'members' and 0 < members.limit <= members.size:
@@ -64,6 +65,13 @@ def cloud_edition_billing_resource_check(resource: str,
            abort(403, error_msg)
        elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
            abort(403, error_msg)
+        elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
+            # The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
+            source = request.args.get('source')
+            if source == 'datasets':
+                abort(403, error_msg)
+            else:
+                return view(*args, **kwargs)
        elif resource == 'workspace_custom' and not features.can_replace_logo:
            abort(403, error_msg)
        elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
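The quota checks above share one predicate: a limit of 0 means unlimited, and the check only trips once a positive limit is reached. A small sketch of that semantics (my paraphrase of the inline expressions, not code from the diff):

```python
def quota_exceeded(limit: int, size: int) -> bool:
    # 0 < limit <= size, as used in the decorator above
    return 0 < limit <= size

assert quota_exceeded(limit=50, size=50)       # at the cap: blocked
assert not quota_exceeded(limit=0, size=999)   # limit 0 means unlimited
assert not quota_exceeded(limit=50, size=10)   # under the cap: allowed
```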
@@ -87,7 +87,7 @@ class TextApi(Resource):
            tenant_id=app_model.tenant_id,
            text=args['text'],
            end_user=end_user,
-            voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
+            voice=args['voice'] if args['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
            streaming=args['streaming']
        )
@@ -28,6 +28,7 @@ class DocumentAddByTextApi(DatasetApiResource):
    """Resource for documents."""

    @cloud_edition_billing_resource_check('vector_space', 'dataset')
+    @cloud_edition_billing_resource_check('documents', 'dataset')
    def post(self, tenant_id, dataset_id):
        """Create document by text."""
        parser = reqparse.RequestParser()
@@ -153,6 +154,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
class DocumentAddByFileApi(DatasetApiResource):
    """Resource for documents."""
    @cloud_edition_billing_resource_check('vector_space', 'dataset')
+    @cloud_edition_billing_resource_check('documents', 'dataset')
    def post(self, tenant_id, dataset_id):
        """Create document by upload file."""
        args = {}
@@ -89,6 +89,7 @@ def cloud_edition_billing_resource_check(resource: str,
        members = features.members
        apps = features.apps
        vector_space = features.vector_space
+        documents_upload_quota = features.documents_upload_quota

        if resource == 'members' and 0 < members.limit <= members.size:
            raise Unauthorized(error_msg)
@@ -96,6 +97,8 @@ def cloud_edition_billing_resource_check(resource: str,
            raise Unauthorized(error_msg)
        elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
            raise Unauthorized(error_msg)
+        elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
+            raise Unauthorized(error_msg)
        else:
            return view(*args, **kwargs)
@@ -84,7 +84,7 @@ class TextApi(WebApiResource):
            tenant_id=app_model.tenant_id,
            text=request.form['text'],
            end_user=end_user.external_user_id,
-            voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
+            voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
            streaming=False
        )
@@ -1,49 +0,0 @@
from typing import cast

from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel


class CalcTokenMixin:

    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
        """
        Got the rest tokens available for the model after excluding messages tokens and completion max tokens

        :param model_config:
        :param messages:
        :return:
        """
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return 0

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            messages
        )

        rest_tokens = model_context_tokens - max_tokens - prompt_tokens

        return rest_tokens


class ExceededLLMTokensLimitError(Exception):
    pass
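The removed mixin's budget is plain arithmetic: context window minus reserved completion tokens minus prompt tokens. A worked example with made-up numbers, showing when the agents below would fall back to summarization:

```python
model_context_tokens = 4096  # model's context window
max_tokens = 512             # reserved for the completion
prompt_tokens = 3700         # tokens consumed by the prompt so far

rest_tokens = model_context_tokens - max_tokens - prompt_tokens
print(rest_tokens)  # -116: negative, so summarize_messages_if_needed kicks in
```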
@@ -1,361 +0,0 @@
from collections.abc import Sequence
from typing import Any, Optional, Union

from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    get_buffer_string,
)
from langchain.tools import BaseTool
from pydantic import root_validator

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM


class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_config: ModelConfigEntity = None
    model_config: ModelConfigEntity
    agent_llm_callback: Optional[AgentLLMCallback] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        return values

    @classmethod
    def from_llm_and_tools(
            cls,
            model_config: ModelConfigEntity,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            agent_llm_callback: Optional[AgentLLMCallback] = None,
            **kwargs: Any,
    ) -> BaseSingleActionAgent:
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_config=model_config,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            agent_llm_callback=agent_llm_callback,
            **kwargs,
        )

    def should_use_agent(self, query: str):
        """
        return should use agent

        :param query:
        :return:
        """
        original_max_tokens = 0
        for parameter_rule in self.model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
                                       or self.model_config.parameters.get(parameter_rule.use_template)) or 0

        self.model_config.parameters['max_tokens'] = 40

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        try:
            prompt_messages = lc_messages_to_prompt_messages(messages)
            model_instance = ModelInstance(
                provider_model_bundle=self.model_config.provider_model_bundle,
                model=self.model_config.model,
            )

            tools = []
            for function in self.functions:
                tool = PromptMessageTool(
                    **function
                )

                tools.append(tool)

            result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                tools=tools,
                stream=False,
                model_parameters={
                    'temperature': 0.2,
                    'top_p': 0.3,
                    'max_tokens': 1500
                }
            )
        except Exception as e:
            raise e

        self.model_config.parameters['max_tokens'] = original_max_tokens

        return True if result.message.tool_calls else False

    def plan(
            self,
            intermediate_steps: list[tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        prompt_messages = lc_messages_to_prompt_messages(messages)

        # summarize messages if rest_tokens < 0
        try:
            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        model_instance = ModelInstance(
            provider_model_bundle=self.model_config.provider_model_bundle,
            model=self.model_config.model,
        )

        tools = []
        for function in self.functions:
            tool = PromptMessageTool(
                **function
            )

            tools.append(tool)

        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            tools=tools,
            stream=False,
            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
            model_parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )

        ai_message = AIMessage(
            content=result.message.content or "",
            additional_kwargs={
                'function_call': {
                    'id': result.message.tool_calls[0].id,
                    **result.message.tool_calls[0].function.dict()
                } if result.message.tool_calls else None
            }
        )
        agent_decision = _parse_ai_message(ai_message)

        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
            tool_inputs = agent_decision.tool_input
            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                tool_inputs['query'] = kwargs['input']
                agent_decision.tool_input = tool_inputs

        return agent_decision

    @classmethod
    def get_system_message(cls):
        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "The current date or current time you know is wrong.\n"
                                     "Respond directly if appropriate.")

    def return_stopped_response(
            self,
            early_stopping_method: str,
            intermediate_steps: list[tuple[AgentAction, str]],
            **kwargs: Any,
    ) -> AgentFinish:
        try:
            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
        except ValueError:
            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")

    def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
        rest_tokens = self.get_message_rest_tokens(
            self.model_config,
            messages,
            **kwargs
        )

        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages

        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages

    def predict_new_summary(
            self, messages: list[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        if model_config.provider == 'azure_openai':
            model = model_config.model
            model = model.replace("gpt-35", "gpt-3.5")
        else:
            model = model_config.credentials.get("base_model_name")

        tiktoken_ = _import_tiktoken()
        try:
            encoding = tiktoken_.encoding_for_model(model)
        except KeyError:
            model = "cl100k_base"
            encoding = tiktoken_.get_encoding(model)

        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}."
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if kwargs.get('functions'):
            for function in kwargs.get('functions'):
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                num_tokens += len(encoding.encode(field_key))
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens
@@ -1,306 +0,0 @@
import re
from collections.abc import Sequence
from typing import Any, Optional, Union, cast

from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    AIMessage,
    BaseMessage,
    HumanMessage,
    OutputParserException,
    get_buffer_string,
)
from langchain.tools import BaseTool

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $JSON_BLOB, as shown:

```
{{{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}}}
```

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}}}
```"""


class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_config: ModelConfigEntity = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def should_use_agent(self, query: str):
        """
        return should use agent
        Using the ReACT mode to determine whether an agent is needed is costly,
        so it's better to just use an Agent for reasoning, which is cheaper.

        :param query:
        :return:
        """
        return True

    def plan(
            self,
            intermediate_steps: list[tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observatons
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
        prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])

        messages = []
        if prompts:
            messages = prompts[0].to_messages()

        prompt_messages = lc_messages_to_prompt_messages(messages)

        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
        if rest_tokens < 0:
            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

        try:
            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        except Exception as e:
            raise e

        try:
            agent_decision = self.output_parser.parse(full_output)
            if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
                tool_inputs = agent_decision.tool_input
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                    tool_inputs['query'] = kwargs['input']
                    agent_decision.tool_input = tool_inputs
            return agent_decision
        except OutputParserException:
            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                          "I don't know how to respond to that."}, "")

    def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
        if len(intermediate_steps) >= 2 and self.summary_model_config:
            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
            should_summary_messages = [AIMessage(content=observation)
                                       for _, observation in should_summary_intermediate_steps]
            if self.moving_summary_index == 0:
                should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))

            self.moving_summary_index = len(intermediate_steps)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        if self.moving_summary_buffer and 'chat_history' in kwargs:
            kwargs["chat_history"].pop()

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        if 'chat_history' in kwargs:
            kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))

        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)

    def predict_new_summary(
            self, messages: list[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    @classmethod
    def create_prompt(
            cls,
            tools: Sequence[BaseTool],
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[list[str]] = None,
            memory_prompts: Optional[list[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        tool_strings = []
        for tool in tools:
            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
        formatted_tools = "\n".join(tool_strings)
        tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
            HumanMessagePromptTemplate.from_template(human_message_template),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def create_completion_prompt(
            cls,
            tools: Sequence[BaseTool],
            prefix: str = PREFIX,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[list[str]] = None,
    ) -> PromptTemplate:
        """Create prompt in the style of the zero shot agent.

        Args:
            tools: List of tools the agent will have access to, used to format the
                prompt.
            prefix: String to put before the list of tools.
            input_variables: List of input variables the final prompt will expect.

        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""

        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)

    def _construct_scratchpad(
            self, intermediate_steps: list[tuple[AgentAction, str]]
    ) -> str:
        agent_scratchpad = ""
        for action, observation in intermediate_steps:
            agent_scratchpad += action.log
            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"

        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if agent_scratchpad:
            llm_chain = cast(LLMChain, self.llm_chain)
            if llm_chain.model_config.mode == "chat":
                return (
                    f"This was your previous work "
                    f"(but I haven't seen any of it! I only see what "
                    f"you return as final answer):\n{agent_scratchpad}"
                )
            else:
                return agent_scratchpad
        else:
            return agent_scratchpad

    @classmethod
    def from_llm_and_tools(
            cls,
            model_config: ModelConfigEntity,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            output_parser: Optional[AgentOutputParser] = None,
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[list[str]] = None,
            memory_prompts: Optional[list[BasePromptTemplate]] = None,
            agent_llm_callback: Optional[AgentLLMCallback] = None,
            **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        if model_config.mode == "chat":
            prompt = cls.create_prompt(
                tools,
                prefix=prefix,
                suffix=suffix,
                human_message_template=human_message_template,
                format_instructions=format_instructions,
                input_variables=input_variables,
                memory_prompts=memory_prompts,
            )
        else:
            prompt = cls.create_completion_prompt(
                tools,
                prefix=prefix,
                format_instructions=format_instructions,
                input_variables=input_variables,
            )
        llm_chain = LLMChain(
            model_config=model_config,
            prompt=prompt,
            callback_manager=callback_manager,
            agent_llm_callback=agent_llm_callback,
            parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )
@@ -84,7 +84,7 @@ class AppRunner:

        return rest_tokens

-    def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
+    def recalc_llm_max_tokens(self, model_config: ModelConfigEntity,
                              prompt_messages: list[PromptMessage]):
        # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
        model_type_instance = model_config.provider_model_bundle.model_type_instance
@@ -1,4 +1,3 @@
-import json
import logging
from typing import cast

@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
-from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
+from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)
@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):

        # convert db variables to tool variables
        tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)

-        message_chain = self._init_message_chain(
-            message=message,
-            query=query
-        )
-
        # init model instance
        model_instance = ModelInstance(
@@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner):
            'pool': db_variables.variables
        })

-    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
-        """
-        Init MessageChain
-        :param message: message
-        :param query: query
-        :return:
-        """
-        message_chain = MessageChain(
-            message_id=message.id,
-            type="AgentExecutor",
-            input=json.dumps({
-                "input": query
-            })
-        )
-
-        db.session.add(message_chain)
-        db.session.commit()
-
-        return message_chain
-
-    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
-        """
-        Save MessageChain
-        :param message_chain: message chain
-        :param output_text: output text
-        :return:
-        """
-        message_chain.output = json.dumps({
-            "output": output_text
-        })
-        db.session.commit()
-
    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                         message: Message) -> LLMUsage:
        """
@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
-from core.features.dataset_retrieval import DatasetRetrievalFeature
+from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
@@ -181,7 +181,7 @@ class BasicApplicationRunner(AppRunner):
            return

        # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
-        self.recale_llm_max_tokens(
+        self.recalc_llm_max_tokens(
            model_config=app_orchestration_config.model_config,
            prompt_messages=prompt_messages
        )
api/core/entities/agent_entities.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from enum import Enum


class PlanningStrategy(Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'
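The new enum centralizes strategy names; the removed runner below picks FUNCTION_CALL only when the model schema advertises tool calling. A condensed, self-contained sketch of that selection, using the enum above (ModelFeature here is a stand-in for the real core.model_runtime enum, with guessed values):

```python
from enum import Enum


class ModelFeature(Enum):  # stand-in for core.model_runtime's enum
    TOOL_CALL = 'tool-call'
    MULTI_TOOL_CALL = 'multi-tool-call'


def choose_strategy(features: list[ModelFeature]) -> PlanningStrategy:
    # Condensed from AgentRunnerFeature.run (removed below): default to ReAct,
    # upgrade to function calling when the model supports tool calls.
    if features and (ModelFeature.TOOL_CALL in features
                     or ModelFeature.MULTI_TOOL_CALL in features):
        return PlanningStrategy.FUNCTION_CALL
    return PlanningStrategy.REACT
```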
@@ -1,199 +0,0 @@
import logging
from typing import Optional, cast

from langchain.tools import BaseTool

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (
AgentEntity,
AppOrchestrationConfigEntity,
InvokeFrom,
ModelConfigEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message

logger = logging.getLogger(__name__)


class AgentRunnerFeature:
def __init__(self, tenant_id: str,
app_orchestration_config: AppOrchestrationConfigEntity,
model_config: ModelConfigEntity,
config: AgentEntity,
queue_manager: ApplicationQueueManager,
message: Message,
user_id: str,
agent_llm_callback: AgentLLMCallback,
callback: AgentLoopGatherCallbackHandler,
memory: Optional[TokenBufferMemory] = None,) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param app_orchestration_config: app orchestration config
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
"""
self.tenant_id = tenant_id
self.app_orchestration_config = app_orchestration_config
self.model_config = model_config
self.config = config
self.queue_manager = queue_manager
self.message = message
self.user_id = user_id
self.agent_llm_callback = agent_llm_callback
self.callback = callback
self.memory = memory

def run(self, query: str,
invoke_from: InvokeFrom) -> Optional[str]:
"""
Retrieve agent loop result.
:param query: query
:param invoke_from: invoke from
:return:
"""
provider = self.config.provider
model = self.config.model
tool_configs = self.config.tools

# check model is support tool calling
provider_instance = model_provider_factory.get_provider_instance(provider=provider)
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
model_type_instance = cast(LargeLanguageModel, model_type_instance)

# get model schema
model_schema = model_type_instance.get_model_schema(
model=model,
credentials=self.model_config.credentials
)

if not model_schema:
return None

planning_strategy = PlanningStrategy.REACT
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.FUNCTION_CALL

tools = self.to_tools(
tool_configs=tool_configs,
invoke_from=invoke_from,
callbacks=[self.callback, DifyStdOutCallbackHandler()],
)

if len(tools) == 0:
return None

agent_configuration = AgentConfiguration(
strategy=planning_strategy,
model_config=self.model_config,
tools=tools,
memory=self.memory,
max_iterations=10,
max_execution_time=400.0,
early_stopping_method="generate",
agent_llm_callback=self.agent_llm_callback,
callbacks=[self.callback, DifyStdOutCallbackHandler()]
)

agent_executor = AgentExecutor(agent_configuration)

try:
# check if should use agent
should_use_agent = agent_executor.should_use_agent(query)
if not should_use_agent:
return None

result = agent_executor.run(query)
return result.output
except Exception as ex:
logger.exception("agent_executor run failed")
return None

def to_dataset_retriever_tool(self, tool_config: dict,
invoke_from: InvokeFrom) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tool_config: tool config
:param invoke_from: invoke from
"""
show_retrieve_source = self.app_orchestration_config.show_retrieve_source

hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=self.queue_manager,
app_id=self.message.app_id,
message_id=self.message.id,
user_id=self.user_id,
invoke_from=invoke_from
)

# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
).first()

# pass if dataset is not available
if not dataset:
return None

# pass if dataset is not available
if (dataset and dataset.available_document_count == 0
and dataset.available_document_count == 0):
return None

# get retrieval model config
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}

retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model

# get top k
top_k = retrieval_model_config['top_k']

# get score threshold
score_threshold = None
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")

tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=show_retrieve_source,
retriever_from=invoke_from.to_source()
)

return tool
@@ -130,8 +130,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
input=query
)

# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# recalc llm max tokens
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
prompt_messages=prompt_messages,

@@ -105,8 +105,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
messages_ids=message_file_ids
)

# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# recalc llm max tokens
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.third_party.langchain.llms.fake import FakeLLM


class LLMChain(LCLLMChain):
@@ -12,9 +12,9 @@ from pydantic import root_validator

from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM


class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool

from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.features.dataset_retrieval.agent.llm_chain import LLMChain

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -1,4 +1,3 @@
import enum
import logging
from typing import Optional, Union

@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool


class PlanningStrategy(str, enum.Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'


class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_config: ModelConfigEntity
@@ -62,28 +53,7 @@ class AgentExecutor:
self.agent = self._init_agent()

def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None, # used for read chat histories memory
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
if self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
@@ -2,9 +2,10 @@ from typing import Optional, cast

from langchain.tools import BaseTool

from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -21,7 +21,7 @@ class AnthropicProvider(ModelProvider):

# Use `claude-instant-1` model for validate,
model_instance.validate_credentials(
model='claude-instant-1',
model='claude-instant-1.2',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
@@ -2,8 +2,8 @@ provider: anthropic
label:
en_US: Anthropic
description:
en_US: Anthropic’s powerful models, such as Claude 2 and Claude Instant.
zh_Hans: Anthropic 的强大模型,例如 Claude 2 和 Claude Instant。
en_US: Anthropic’s powerful models, such as Claude 3.
zh_Hans: Anthropic 的强大模型,例如 Claude 3。
icon_small:
en_US: icon_s_en.svg
icon_large:
@@ -0,0 +1,6 @@
- claude-3-opus-20240229
- claude-3-sonnet-20240229
- claude-2.1
- claude-instant-1.2
- claude-2
- claude-instant-1
@@ -34,3 +34,4 @@ pricing:
output: '24.00'
unit: '0.000001'
currency: USD
deprecated: true
@@ -0,0 +1,37 @@
model: claude-3-opus-20240229
label:
en_US: claude-3-opus-20240229
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '15.00'
output: '75.00'
unit: '0.000001'
currency: USD
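A note on reading these pricing blocks: input and output are prices per unit tokens in the given currency, so claude-3-opus at input '15.00' with unit '0.000001' comes to $15 per million prompt tokens and $75 per million completion tokens. A minimal sketch of the arithmetic, using the opus figures above as defaults:

from decimal import Decimal

def estimate_cost(prompt_tokens: int, completion_tokens: int,
                  input_price: str = '15.00', output_price: str = '75.00',
                  unit: str = '0.000001') -> Decimal:
    # price per token = price * unit; total = tokens * price-per-token
    u = Decimal(unit)
    return (prompt_tokens * Decimal(input_price) * u
            + completion_tokens * Decimal(output_price) * u)

# estimate_cost(1_000_000, 100_000) == Decimal('22.5')  # $15 prompt + $7.50 completion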
@@ -0,0 +1,37 @@
model: claude-3-sonnet-20240229
label:
en_US: claude-3-sonnet-20240229
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '3.00'
output: '15.00'
unit: '0.000001'
currency: USD
@@ -0,0 +1,35 @@
model: claude-instant-1.2
label:
en_US: claude-instant-1.2
model_type: llm
features: [ ]
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '1.63'
output: '5.51'
unit: '0.000001'
currency: USD
@@ -33,3 +33,4 @@ pricing:
output: '5.51'
unit: '0.000001'
currency: USD
deprecated: true
@@ -1,18 +1,32 @@
import base64
import mimetypes
from collections.abc import Generator
from typing import Optional, Union
from typing import Optional, Union, cast

import anthropic
import requests
from anthropic import Anthropic, Stream
from anthropic.types import Completion, completion_create_params
from anthropic.types import (
ContentBlockDeltaEvent,
Message,
MessageDeltaEvent,
MessageStartEvent,
MessageStopEvent,
MessageStreamEvent,
completion_create_params,
)
from httpx import Timeout

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (

@@ -35,6 +49,7 @@ if you are not sure about the structure.
</instructions>
"""


class AnthropicLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
@@ -55,54 +70,114 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke llm chat model

:param model: model name
:param credentials: credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)

# transform model parameters from completion api of anthropic to chat api
if 'max_tokens_to_sample' in model_parameters:
model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample')

# init model client
client = Anthropic(**credentials_kwargs)

extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop

if user:
extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)

system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages)

if system:
extra_model_kwargs['system'] = system

# chat model
response = client.messages.create(
model=model,
messages=prompt_message_dicts,
stream=stream,
**model_parameters,
**extra_model_kwargs
)

if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)

return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
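For orientation, the method above is a thin wrapper around the Anthropic Messages API; a standalone sketch of the equivalent direct call (the API key and prompt are illustrative placeholders):

from anthropic import Anthropic

client = Anthropic(api_key="sk-ant-...")  # placeholder credential
response = client.messages.create(
    model="claude-3-sonnet-20240229",
    max_tokens=1024,
    system="You are a helpful assistant.",  # system text is a top-level kwarg, not a message
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.content[0].text)  # content is a list of blocks, hence [0].text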
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if 'response_format' in model_parameters and model_parameters['response_format']:
stop = stop or []
self._transform_json_prompts(
model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format']
# chat model
self._transform_chat_json_prompts(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
response_format=model_parameters['response_format']
)
model_parameters.pop('response_format')

return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

def _transform_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
def _transform_chat_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")

# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))

prompt_messages.append(AssistantPromptMessage(
content=f"```{response_format}\n"
))
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))

def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -129,7 +204,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return:
"""
try:
self._generate(
self._chat_generate(
model=model,
credentials=credentials,
prompt_messages=[
@@ -137,58 +212,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
],
model_parameters={
"temperature": 0,
"max_tokens_to_sample": 20,
"max_tokens": 20,
},
stream=False
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Invoke large language model

:param model: model name
:param credentials: credentials kwargs
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)

client = Anthropic(**credentials_kwargs)

extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop

if user:
extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)

response = client.completions.create(
model=model,
prompt=self._convert_messages_to_prompt_anthropic(prompt_messages),
stream=stream,
**model_parameters,
**extra_model_kwargs
)

if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

return self._handle_generate_response(model, credentials, response, prompt_messages)

def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Handle llm response
Handle llm chat response

:param model: model name
:param credentials: credentials
@@ -198,75 +232,89 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.completion
content=response.content[0].text
)

# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
if response.usage:
# transform usage
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

# transform response
result = LLMResult(
response = LLMResult(
model=response.model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
usage=usage
)

return result
return response

def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
response: Stream[MessageStreamEvent],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
Handle llm chat stream response

:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator result
:return: llm response chunk generator
"""
index = -1
full_assistant_content = ''
return_model = None
input_tokens = 0
output_tokens = 0
finish_reason = None
index = 0
for chunk in response:
content = chunk.completion
if chunk.stop_reason is None and (content is None or content == ''):
continue

# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=content if content else '',
)

index += 1

if chunk.stop_reason is not None:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

if isinstance(chunk, MessageStartEvent):
return_model = chunk.message.model
input_tokens = chunk.message.usage.input_tokens
elif isinstance(chunk, MessageDeltaEvent):
output_tokens = chunk.usage.output_tokens
finish_reason = chunk.delta.stop_reason
elif isinstance(chunk, MessageStopEvent):
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)

yield LLMResultChunk(
model=chunk.model,
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
finish_reason=chunk.stop_reason,
index=index + 1,
message=AssistantPromptMessage(
content=''
),
finish_reason=finish_reason,
usage=usage
)
)
else:
elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text if chunk.delta.text else ''
full_assistant_content += chunk_text

# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text
)

index = chunk.index

yield LLMResultChunk(
model=chunk.model,
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
index=chunk.index,
message=assistant_prompt_message,
)
)
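The stream handler above multiplexes the SDK's event types; a condensed sketch of the same event flow, assuming the anthropic types imported at the top of this file (input tokens arrive on MessageStartEvent, output tokens and stop reason on MessageDeltaEvent, text on ContentBlockDeltaEvent):

from anthropic.types import ContentBlockDeltaEvent, MessageDeltaEvent, MessageStartEvent

def consume(stream):
    model, text = None, ''
    input_tokens = output_tokens = 0
    finish_reason = None
    for event in stream:
        if isinstance(event, MessageStartEvent):
            model = event.message.model
            input_tokens = event.message.usage.input_tokens
        elif isinstance(event, ContentBlockDeltaEvent):
            text += event.delta.text or ''
        elif isinstance(event, MessageDeltaEvent):
            output_tokens = event.usage.output_tokens
            finish_reason = event.delta.stop_reason
    return model, text, input_tokens, output_tokens, finish_reason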
@@ -289,6 +337,80 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):

return credentials_kwargs

def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
"""
Convert prompt messages to dict list and system
"""
system = ""
prompt_message_dicts = []

for message in prompt_messages:
if isinstance(message, SystemPromptMessage):
system += message.content + ("\n" if not system else "")
else:
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))

return system, prompt_message_dicts

def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]

if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
raise ValueError(f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp")

sub_message_dict = {
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_data
}
}
sub_messages.append(sub_message_dict)

message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")

return message_dict

def _convert_one_message_to_text(self, message: PromptMessage) -> str:
"""
Convert a single message to a string.
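For the vision branch above, this is the content-block shape the converter produces; a minimal sketch for a local file (the path is illustrative, and media_type must be one of the four types the code accepts):

import base64

with open("example.png", "rb") as f:  # illustrative path
    b64_data = base64.b64encode(f.read()).decode("utf-8")

user_message = {
    "role": "user",
    "content": [
        {"type": "image",
         "source": {"type": "base64", "media_type": "image/png", "data": b64_data}},
        {"type": "text", "text": "Describe this image."},
    ],
}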
@@ -2,7 +2,7 @@ provider: jina
label:
en_US: Jina
description:
en_US: Embedding Model Supported
en_US: Embedding and Rerank Model Supported
icon_small:
en_US: icon_s_en.svg
icon_large:

@@ -13,9 +13,10 @@ help:
en_US: Get your API key from Jina AI
zh_Hans: 从 Jina 获取 API Key
url:
en_US: https://jina.ai/embeddings/
en_US: https://jina.ai/
supported_model_types:
- text-embedding
- rerank
configurate_methods:
- predefined-model
provider_credential_schema:
@@ -0,0 +1,4 @@
model: jina-reranker-v1-base-en
model_type: rerank
model_properties:
context_size: 8192
105 api/core/model_runtime/model_providers/jina/rerank/rerank.py Normal file
@@ -0,0 +1,105 @@
from typing import Optional

import httpx

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class JinaRerankModel(RerankModel):
"""
Model class for Jina rerank model.
"""

def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) -> RerankResult:
"""
Invoke rerank model

:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n documents to return
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])

try:
response = httpx.post(
"https://api.jina.ai/v1/rerank",
json={
"model": model,
"query": query,
"documents": docs,
"top_n": top_n
},
headers={"Authorization": f"Bearer {credentials.get('api_key')}"}
)
response.raise_for_status()
results = response.json()

rerank_documents = []
for result in results['results']:
rerank_document = RerankDocument(
index=result['index'],
text=result['document']['text'],
score=result['relevance_score'],
)
if score_threshold is None or result['relevance_score'] >= score_threshold:
rerank_documents.append(rerank_document)

return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials

:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
"""
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError]
}
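For manual testing, the request this class issues can be reproduced standalone; everything below mirrors the payload in _invoke above, with a placeholder API key:

import httpx

resp = httpx.post(
    "https://api.jina.ai/v1/rerank",
    json={
        "model": "jina-reranker-v1-base-en",
        "query": "What is the capital of the United States?",
        "documents": [
            "Carson City is the capital city of the American state of Nevada.",
            "Washington, D.C. is the capital of the United States.",
        ],
        "top_n": 2,
    },
    headers={"Authorization": "Bearer <api-key>"},  # placeholder credential
)
resp.raise_for_status()
for result in resp.json()["results"]:
    print(result["index"], result["relevance_score"])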
@@ -1,6 +1,6 @@
from collections.abc import Generator
from os.path import join
from typing import cast
from urllib.parse import urljoin

from httpx import Timeout
from openai import (

@@ -313,10 +313,13 @@ class LocalAILarguageModel(LargeLanguageModel):
:param credentials: credentials dict
:return: client kwargs
"""
if not credentials['server_url'].endswith('/'):
credentials['server_url'] += '/'

client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1",
"base_url": join(credentials['server_url'], 'v1'),
"base_url": urljoin(credentials['server_url'], 'v1'),
}

return client_kwargs
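The import swap above fixes URL construction on Windows: os.path.join uses the platform path separator, while urllib.parse.urljoin resolves against the base URL on every platform. Note that urljoin replaces the last path segment unless the base ends with a slash, which is why the code normalizes server_url first. A quick illustration:

from urllib.parse import urljoin

print(urljoin("http://localhost:8080/", "v1"))      # http://localhost:8080/v1
print(urljoin("http://localhost:8080/api", "v1"))   # http://localhost:8080/v1  (last segment replaced)
print(urljoin("http://localhost:8080/api/", "v1"))  # http://localhost:8080/api/v1
# os.path.join("http://localhost:8080", "v1") would use "\\" on Windows, producing an invalid URL.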
@@ -34,7 +34,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
:return: text translated to audio file
"""
audio_type = self._get_model_audio_type(model, credentials)
if not voice:
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials)
if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model,

@@ -34,7 +34,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
:return: text translated to audio file
"""
audio_type = self._get_model_audio_type(model, credentials)
if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials)
if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
@@ -140,7 +140,8 @@ class MilvusVector(BaseVector):
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)

from pymilvus import utility
utility.drop_collection(self._collection_name, None, using=alias)
if utility.has_collection(self._collection_name, using=alias):
utility.drop_collection(self._collection_name, None, using=alias)

def text_exists(self, id: str) -> bool:
@@ -231,21 +231,30 @@ class QdrantVector(BaseVector):

def delete(self):
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
from qdrant_client.http.exceptions import UnexpectedResponse

try:
filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete_by_ids(self, ids: list[str]) -> None:

from qdrant_client.http import models
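The intent of the Qdrant change: deleting points from a collection that was never created used to raise, and the new code treats a 404 as nothing to delete. A minimal standalone sketch of the same pattern (collection and group names are illustrative):

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse

def delete_group(client: QdrantClient, collection_name: str, group_id: str) -> None:
    try:
        client.delete(
            collection_name=collection_name,
            points_selector=models.FilterSelector(
                filter=models.Filter(must=[
                    models.FieldCondition(key="group_id",
                                          match=models.MatchValue(value=group_id)),
                ])
            ),
        )
    except UnexpectedResponse as e:
        if e.status_code != 404:  # 404: collection absent, nothing to delete
            raise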
@@ -39,7 +39,7 @@ class Vector:
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'weaviate',
"vector_store": {"class_prefix": collection_name}

@@ -70,7 +70,7 @@ class Vector:
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
collection_name = Dataset.gen_collection_name_by_id(dataset_id)

if not self._dataset.index_struct_dict:
index_struct_dict = {

@@ -96,7 +96,7 @@ class Vector:
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'milvus',
"vector_store": {"class_prefix": collection_name}

@@ -70,7 +70,7 @@ class WeaviateVector(BaseVector):
return class_prefix

dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
return Dataset.gen_collection_name_by_id(dataset_id)

def to_index_struct(self) -> dict:
return {
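The repeated inline expression being replaced documents exactly what the new helper centralizes; a plausible sketch of the classmethod, inferred from the old call sites rather than quoted from the commit:

class Dataset:
    @classmethod
    def gen_collection_name_by_id(cls, dataset_id: str) -> str:
        # same naming scheme the call sites previously inlined
        normalized = dataset_id.replace("-", "_")
        return f"Vector_index_{normalized}_Node"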
189 api/core/third_party/spark/spark_llm.py vendored
@@ -1,189 +0,0 @@
import base64
import hashlib
import hmac
import json
import queue
import ssl
from datetime import datetime
from time import mktime
from typing import Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time

import websocket


class SparkLLMClient:
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
domain = 'spark-api.xf-yun.com'
endpoint = 'chat'
if api_domain:
domain = api_domain
if model_name == 'spark-v3':
endpoint = 'multimodal'

model_api_configs = {
'spark': {
'version': 'v1.1',
'chat_domain': 'general'
},
'spark-v2': {
'version': 'v2.1',
'chat_domain': 'generalv2'
},
'spark-v3': {
'version': 'v3.1',
'chat_domain': 'generalv3'
},
'spark-v3.5': {
'version': 'v3.5',
'chat_domain': 'generalv3.5'
}
}

api_version = model_api_configs[model_name]['version']

self.chat_domain = model_api_configs[model_name]['chat_domain']
self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
self.app_id = app_id
self.ws_url = self.create_url(
urlparse(self.api_base).netloc,
urlparse(self.api_base).path,
self.api_base,
api_key,
api_secret
)

self.queue = queue.Queue()
self.blocking_message = ''

def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
# generate timestamp by RFC1123
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))

signature_origin = "host: " + host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + path + " HTTP/1.1"

# encrypt using hmac-sha256
signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()

signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

v = {
"authorization": authorization,
"date": date,
"host": host
}
# generate url
url = api_base + '?' + urlencode(v)
return url

def run(self, messages: list, user_id: str,
model_kwargs: Optional[dict] = None, streaming: bool = False):
websocket.enableTrace(False)
ws = websocket.WebSocketApp(
self.ws_url,
on_message=self.on_message,
on_error=self.on_error,
on_close=self.on_close,
on_open=self.on_open
)
ws.messages = messages
ws.user_id = user_id
ws.model_kwargs = model_kwargs
ws.streaming = streaming
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

def on_error(self, ws, error):
self.queue.put({
'status_code': error.status_code,
'error': error.resp_body.decode('utf-8')
})
ws.close()

def on_close(self, ws, close_status_code, close_reason):
self.queue.put({'done': True})

def on_open(self, ws):
self.blocking_message = ''
data = json.dumps(self.gen_params(
messages=ws.messages,
user_id=ws.user_id,
model_kwargs=ws.model_kwargs
))
ws.send(data)

def on_message(self, ws, message):
data = json.loads(message)
code = data['header']['code']
if code != 0:
self.queue.put({
'status_code': 400,
'error': f"Code: {code}, Error: {data['header']['message']}"
})
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
if ws.streaming:
self.queue.put({'data': content})
else:
self.blocking_message += content

if status == 2:
if not ws.streaming:
self.queue.put({'data': self.blocking_message})
ws.close()

def gen_params(self, messages: list, user_id: str,
model_kwargs: Optional[dict] = None) -> dict:
data = {
"header": {
"app_id": self.app_id,
"uid": user_id
},
"parameter": {
"chat": {
"domain": self.chat_domain
}
},
"payload": {
"message": {
"text": messages
}
}
}

if model_kwargs:
data['parameter']['chat'].update(model_kwargs)

return data

def subscribe(self):
while True:
content = self.queue.get()
if 'error' in content:
if content['status_code'] == 401:
raise SparkError('[Spark] The credentials you provided are incorrect. '
'Please double-check and fill them in again.')
elif content['status_code'] == 403:
raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
"Please try again after obtaining the necessary permissions.")
else:
raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")

if 'data' not in content:
break
yield content


class SparkError(Exception):
pass
@@ -1,24 +0,0 @@
from datetime import datetime

from langchain.tools import BaseTool
from pydantic import BaseModel, Field


class DatetimeToolInput(BaseModel):
type: str = Field(..., description="Type for current time, must be: datetime.")


class DatetimeTool(BaseTool):
"""Tool for querying current datetime."""
name: str = "current_datetime"
args_schema: type[BaseModel] = DatetimeToolInput
description: str = "A tool when you want to get the current date, time, week, month or year, " \
"and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\"."

def _run(self, type: str) -> str:
# get current time
current_time = datetime.utcnow()
return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")

async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()
@@ -1,63 +0,0 @@
import base64
from abc import ABC, abstractmethod
from typing import Optional

from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.tool import ToolProvider, ToolProviderName


class BaseToolProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id

@abstractmethod
def get_provider_name(self) -> ToolProviderName:
raise NotImplementedError

@abstractmethod
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
raise NotImplementedError

@abstractmethod
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
raise NotImplementedError

@abstractmethod
def credentials_to_func_kwargs(self) -> Optional[dict]:
raise NotImplementedError

@abstractmethod
def credentials_validate(self, credentials: dict):
raise NotImplementedError

def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
"""
Returns the Provider instance for the given tenant_id and tool_name.
"""
query = db.session.query(ToolProvider).filter(
ToolProvider.tenant_id == self.tenant_id,
ToolProvider.tool_name == self.get_provider_name().value
)

if must_enabled:
query = query.filter(ToolProvider.is_enabled == True)

return query.first()

def encrypt_token(self, token) -> str:
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()

def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
token = rsa.decrypt(base64.b64decode(token), self.tenant_id)

if obfuscated:
return self._obfuscated_token(token)

return token

def _obfuscated_token(self, token: str) -> str:
return token[:6] + '*' * (len(token) - 8) + token[-2:]
@@ -1,2 +0,0 @@
class ToolValidateFailedError(Exception):
description = "Tool Provider Validate failed"
@@ -1,77 +0,0 @@
from typing import Optional

from core.tool.provider.base import BaseToolProvider
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
from models.tool import ToolProviderName


class SerpAPIToolProvider(BaseToolProvider):
def get_provider_name(self) -> ToolProviderName:
"""
Returns the name of the provider.

:return:
"""
return ToolProviderName.SERPAPI

def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Returns the credentials for SerpAPI as a dictionary.

:param obfuscated: obfuscate credentials if True
:return:
"""
tool_provider = self.get_provider(must_enabled=True)
if not tool_provider:
return None

credentials = tool_provider.credentials
if not credentials:
return None

if credentials.get('api_key'):
credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)

return credentials

def credentials_to_func_kwargs(self) -> Optional[dict]:
"""
Returns the credentials function kwargs as a dictionary.

:return:
"""
credentials = self.get_credentials()
if not credentials:
return None

return {
'serpapi_api_key': credentials.get('api_key')
}

def credentials_validate(self, credentials: dict):
"""
Validates the given credentials.

:param credentials:
:return:
"""
if 'api_key' not in credentials or not credentials.get('api_key'):
raise ToolValidateFailedError("SerpAPI api_key is required.")

api_key = credentials.get('api_key')

try:
OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
except Exception as e:
raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))

def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
"""
Encrypts the given credentials.

:param credentials:
:return:
"""
credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
return credentials
@@ -1,43 +0,0 @@
from typing import Optional

from core.tool.provider.base import BaseToolProvider
from core.tool.provider.serpapi_provider import SerpAPIToolProvider


class ToolProviderService:

def __init__(self, tenant_id: str, provider_name: str):
self.provider = self._init_provider(tenant_id, provider_name)

def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
if provider_name == 'serpapi':
return SerpAPIToolProvider(tenant_id)
else:
raise Exception('tool provider {} not found'.format(provider_name))

def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Returns the credentials for Tool as a dictionary.

:param obfuscated:
:return:
"""
return self.provider.get_credentials(obfuscated)

def credentials_validate(self, credentials: dict):
"""
Validates the given credentials.

:param credentials:
:raises: ValidateFailedError
"""
return self.provider.credentials_validate(credentials)

def encrypt_credentials(self, credentials: dict):
"""
Encrypts the given credentials.

:param credentials:
:return:
"""
return self.provider.encrypt_credentials(credentials)
@@ -1,51 +0,0 @@
from langchain import SerpAPIWrapper
from pydantic import BaseModel, Field


class OptimizedSerpAPIInput(BaseModel):
    query: str = Field(..., description="search query.")


class OptimizedSerpAPIWrapper(SerpAPIWrapper):

    @staticmethod
    def _process_response(res: dict, num_results: int = 5) -> str:
        """Process response from SerpAPI."""
        if "error" in res.keys():
            raise ValueError(f"Got error from SerpAPI: {res['error']}")
        if "answer_box" in res.keys() and type(res["answer_box"]) == list:
            res["answer_box"] = res["answer_box"][0]
        if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
            toret = res["answer_box"]["answer"]
        elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
            toret = res["answer_box"]["snippet"]
        elif (
            "answer_box" in res.keys()
            and "snippet_highlighted_words" in res["answer_box"].keys()
        ):
            toret = res["answer_box"]["snippet_highlighted_words"][0]
        elif (
            "sports_results" in res.keys()
            and "game_spotlight" in res["sports_results"].keys()
        ):
            toret = res["sports_results"]["game_spotlight"]
        elif (
            "shopping_results" in res.keys()
            and "title" in res["shopping_results"][0].keys()
        ):
            toret = res["shopping_results"][:3]
        elif (
            "knowledge_graph" in res.keys()
            and "description" in res["knowledge_graph"].keys()
        ):
            toret = res["knowledge_graph"]["description"]
        elif 'organic_results' in res.keys() and len(res['organic_results']) > 0:
            toret = ""
            for result in res["organic_results"][:num_results]:
                if "link" in result:
                    toret += "----------------\nlink: " + result["link"] + "\n"
                if "snippet" in result:
                    toret += "snippet: " + result["snippet"] + "\n"
        else:
            toret = "No good search result found"
        return "search result:\n" + toret
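For reference, a minimal sketch of what the deleted _process_response helper produced: given only organic results it concatenates link/snippet pairs under a "search result:" banner. The res dictionary below is a hypothetical stand-in, not a real SerpAPI response.

# Hypothetical payload; illustrates the organic_results branch only.
res = {
    "organic_results": [
        {"link": "https://example.com", "snippet": "An example snippet."},
    ]
}
print(OptimizedSerpAPIWrapper._process_response(res, num_results=5))
# search result:
# ----------------
# link: https://example.com
# snippet: An example snippet.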
@@ -1,443 +0,0 @@
import hashlib
import json
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from typing import Any

import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex

from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document

FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:

{text}
"""


class WebReaderToolInput(BaseModel):
    url: str = Field(..., description="URL of the website to read")
    summary: bool = Field(
        default=False,
        description="When the user's question requires extracting the summarizing content of the webpage, "
                    "set it to true."
    )
    cursor: int = Field(
        default=0,
        description="Start reading from this character."
                    "Use when the first response was truncated"
                    "and you want to continue reading the page."
                    "The value cannot exceed 24000.",
    )


class WebReaderTool(BaseTool):
    """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""

    name: str = "web_reader"
    args_schema: type[BaseModel] = WebReaderToolInput
    description: str = "use this to read a website. " \
                       "If you can answer the question based on the information provided, " \
                       "there is no need to use."
    page_contents: str = None
    url: str = None
    max_chunk_length: int = 4000
    summary_chunk_tokens: int = 4000
    summary_chunk_overlap: int = 0
    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
    continue_reading: bool = True
    model_config: ModelConfigEntity
    model_parameters: dict[str, Any]

    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
        try:
            if not self.page_contents or self.url != url:
                page_contents = get_url(url)
                self.page_contents = page_contents
                self.url = url
            else:
                page_contents = self.page_contents
        except Exception as e:
            return f'Read this website failed, caused by: {str(e)}.'

        if summary:
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=self.summary_chunk_tokens,
                chunk_overlap=self.summary_chunk_overlap,
                separators=self.summary_separators
            )

            texts = character_splitter.split_text(page_contents)
            docs = [Document(page_content=t) for t in texts]

            if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
                return "No content found."

            # only use first 5 docs
            if len(docs) > 5:
                docs = docs[:5]

            chain = self.get_summary_chain()
            try:
                page_contents = chain.run(docs)
            except Exception as e:
                return f'Read this website failed, caused by: {str(e)}.'
        else:
            page_contents = page_result(page_contents, cursor, self.max_chunk_length)

        if self.continue_reading and len(page_contents) >= self.max_chunk_length:
            page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
                             f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
                             f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."

        return page_contents

    async def _arun(self, url: str) -> str:
        raise NotImplementedError

    def get_summary_chain(self) -> RefineDocumentsChain:
        initial_chain = LLMChain(
            model_config=self.model_config,
            prompt=refine_prompts.PROMPT,
            parameters=self.model_parameters
        )
        refine_chain = LLMChain(
            model_config=self.model_config,
            prompt=refine_prompts.REFINE_PROMPT,
            parameters=self.model_parameters
        )
        return RefineDocumentsChain(
            initial_llm_chain=initial_chain,
            refine_llm_chain=refine_chain,
            document_variable_name="text",
            initial_response_name="existing_answer",
            callbacks=self.callbacks
        )


def page_result(text: str, cursor: int, max_length: int) -> str:
    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
    return text[cursor: cursor + max_length]


def get_url(url: str) -> str:
    """Fetch URL and return the contents as a string."""
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    }
    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]

    head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))

    if head_response.status_code != 200:
        return "URL returned status code {}.".format(head_response.status_code)

    # check content-type
    main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip()
    if main_content_type not in supported_content_types:
        return "Unsupported content-type [{}] of URL.".format(main_content_type)

    if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
        return ExtractProcessor.load_from_url(url, return_text=True)

    response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
    a = extract_using_readabilipy(response.text)

    if not a['plain_text'] or not a['plain_text'].strip():
        return get_url_from_newspaper3k(url)

    res = FULL_TEMPLATE.format(
        title=a['title'],
        authors=a['byline'],
        publish_date=a['date'],
        top_image="",
        text=a['plain_text'] if a['plain_text'] else "",
    )

    return res


def get_url_from_newspaper3k(url: str) -> str:
    a = Article(url)
    a.download()
    a.parse()

    res = FULL_TEMPLATE.format(
        title=a.title,
        authors=a.authors,
        publish_date=a.publish_date,
        top_image=a.top_image,
        text=a.text,
    )

    return res


def extract_using_readabilipy(html):
    with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
        f_html.write(html)
        f_html.close()
    html_path = f_html.name

    # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
    article_json_path = html_path + ".json"
    jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
    with chdir(jsdir):
        subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])

    # Read output of call to Readability.parse() from JSON file and return as Python dictionary
    with open(article_json_path, encoding="utf-8") as json_file:
        input_json = json.loads(json_file.read())

    # Deleting files after processing
    os.unlink(article_json_path)
    os.unlink(html_path)

    article_json = {
        "title": None,
        "byline": None,
        "date": None,
        "content": None,
        "plain_content": None,
        "plain_text": None
    }
    # Populate article fields from readability fields where present
    if input_json:
        if "title" in input_json and input_json["title"]:
            article_json["title"] = input_json["title"]
        if "byline" in input_json and input_json["byline"]:
            article_json["byline"] = input_json["byline"]
        if "date" in input_json and input_json["date"]:
            article_json["date"] = input_json["date"]
        if "content" in input_json and input_json["content"]:
            article_json["content"] = input_json["content"]
            article_json["plain_content"] = plain_content(article_json["content"], False, False)
            article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
        if "textContent" in input_json and input_json["textContent"]:
            article_json["plain_text"] = input_json["textContent"]
            article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])

    return article_json


def find_module_path(module_name):
    for package_path in site.getsitepackages():
        potential_path = os.path.join(package_path, module_name)
        if os.path.exists(potential_path):
            return potential_path

    return None


@contextmanager
def chdir(path):
    """Change directory in context and return to original on exit"""
    # From https://stackoverflow.com/a/37996581, couldn't find a built-in
    original_path = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original_path)


def extract_text_blocks_as_plain_text(paragraph_html):
    # Load article as DOM
    soup = BeautifulSoup(paragraph_html, 'html.parser')
    # Select all lists
    list_elements = soup.find_all(['ul', 'ol'])
    # Prefix text in all list items with "* " and make lists paragraphs
    for list_element in list_elements:
        plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
        list_element.string = plain_items
        list_element.name = "p"
    # Select all text blocks
    text_blocks = [s.parent for s in soup.find_all(string=True)]
    text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
    # Drop empty paragraphs
    text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
    return text_blocks


def plain_text_leaf_node(element):
    # Extract all text, stripped of any child HTML elements and normalise it
    plain_text = normalise_text(element.get_text())
    if plain_text != "" and element.name == "li":
        plain_text = "* {}, ".format(plain_text)
    if plain_text == "":
        plain_text = None
    if "data-node-index" in element.attrs:
        plain = {"node_index": element["data-node-index"], "text": plain_text}
    else:
        plain = {"text": plain_text}
    return plain


def plain_content(readability_content, content_digests, node_indexes):
    # Load article as DOM
    soup = BeautifulSoup(readability_content, 'html.parser')
    # Make all elements plain
    elements = plain_elements(soup.contents, content_digests, node_indexes)
    if node_indexes:
        # Add node index attributes to nodes
        elements = [add_node_indexes(element) for element in elements]
    # Replace article contents with plain elements
    soup.contents = elements
    return str(soup)


def plain_elements(elements, content_digests, node_indexes):
    # Get plain content versions of all elements
    elements = [plain_element(element, content_digests, node_indexes)
                for element in elements]
    if content_digests:
        # Add content digest attribute to nodes
        elements = [add_content_digest(element) for element in elements]
    return elements


def plain_element(element, content_digests, node_indexes):
    # For lists, we make each item plain text
    if is_leaf(element):
        # For leaf node elements, extract the text content, discarding any HTML tags
        # 1. Get element contents as text
        plain_text = element.get_text()
        # 2. Normalise the extracted text string to a canonical representation
        plain_text = normalise_text(plain_text)
        # 3. Update element content to be plain text
        element.string = plain_text
    elif is_text(element):
        if is_non_printing(element):
            # The simplified HTML may have come from Readability.js so might
            # have non-printing text (e.g. Comment or CData). In this case, we
            # keep the structure, but ensure that the string is empty.
            element = type(element)("")
        else:
            plain_text = element.string
            plain_text = normalise_text(plain_text)
            element = type(element)(plain_text)
    else:
        # If not a leaf node or leaf type call recursively on child nodes, replacing
        element.contents = plain_elements(element.contents, content_digests, node_indexes)
    return element


def add_node_indexes(element, node_index="0"):
    # Can't add attributes to string types
    if is_text(element):
        return element
    # Add index to current element
    element["data-node-index"] = node_index
    # Add index to child elements
    for local_idx, child in enumerate(
            [c for c in element.contents if not is_text(c)], start=1):
        # Can't add attributes to leaf string types
        child_index = "{stem}.{local}".format(
            stem=node_index, local=local_idx)
        add_node_indexes(child, node_index=child_index)
    return element


def normalise_text(text):
    """Normalise unicode and whitespace."""
    # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
    text = strip_control_characters(text)
    text = normalise_unicode(text)
    text = normalise_whitespace(text)
    return text


def strip_control_characters(text):
    """Strip out unicode control characters which might break the parsing."""
    # Unicode control characters
    #   [Cc]: Other, Control [includes new lines]
    #   [Cf]: Other, Format
    #   [Cn]: Other, Not Assigned
    #   [Co]: Other, Private Use
    #   [Cs]: Other, Surrogate
    control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
    retained_chars = ['\t', '\n', '\r', '\f']

    # Remove non-printing control characters
    return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])


def normalise_unicode(text):
    """Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
    normal_form = "NFKC"
    text = unicodedata.normalize(normal_form, text)
    return text


def normalise_whitespace(text):
    """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
    text = regex.sub(r"\s+", " ", text)
    # Remove leading and trailing whitespace
    text = text.strip()
    return text


def is_leaf(element):
    return (element.name in ['p', 'li'])


def is_text(element):
    return isinstance(element, NavigableString)


def is_non_printing(element):
    return any(isinstance(element, _e) for _e in [Comment, CData])


def add_content_digest(element):
    if not is_text(element):
        element["data-content-digest"] = content_digest(element)
    return element


def content_digest(element):
    if is_text(element):
        # Hash
        trimmed_string = element.string.strip()
        if trimmed_string == "":
            digest = ""
        else:
            digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
    else:
        contents = element.contents
        num_contents = len(contents)
        if num_contents == 0:
            # No hash when no child elements exist
            digest = ""
        elif num_contents == 1:
            # If single child, use digest of child
            digest = content_digest(contents[0])
        else:
            # Build content digest from the "non-empty" digests of child nodes
            digest = hashlib.sha256()
            child_digests = list(
                filter(lambda x: x != "", [content_digest(content) for content in contents]))
            for child in child_digests:
                digest.update(child.encode('utf-8'))
            digest = digest.hexdigest()
    return digest
@@ -1,16 +1,20 @@
- google
- bing
- duckduckgo
- yahoo
- wikipedia
- arxiv
- pubmed
- dalle
- azuredalle
- stablediffusion
- webscraper
- youtube
- wolframalpha
- maths
- github
- chart
- time
- yahoo
- stablediffusion
- vectorizer
- youtube
- gaode
- maths
- wecom
@@ -55,6 +55,21 @@ class ApiBasedToolProviderController(ToolProviderController):
                    en_US='The api key',
                    zh_Hans='api key的值'
                )
            ),
            'api_key_header_prefix': ToolProviderCredentials(
                name='api_key_header_prefix',
                required=False,
                default='basic',
                type=ToolProviderCredentials.CredentialsType.SELECT,
                help=I18nObject(
                    en_US='The prefix of the api key header',
                    zh_Hans='api key header 的前缀'
                ),
                options=[
                    ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
                    ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
                    ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
                ]
            )
        }
        elif auth_type == ApiProviderAuthType.NONE:
api/core/tools/provider/builtin/arxiv/_assets/icon.svg (new file, 874 B)
@@ -0,0 +1 @@
<svg id="logomark" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 17.732 24.269"><g id="tiny"><path d="M573.549,280.916l2.266,2.738,6.674-7.84c.353-.47.52-.717.353-1.117a1.218,1.218,0,0,0-1.061-.748h0a.953.953,0,0,0-.712.262Z" transform="translate(-566.984 -271.548)" fill="#bdb9b4"/><path d="M579.525,282.225l-10.606-10.174a1.413,1.413,0,0,0-.834-.5,1.09,1.09,0,0,0-1.027.66c-.167.4-.047.681.319,1.206l8.44,10.242h0l-6.282,7.716a1.336,1.336,0,0,0-.323,1.3,1.114,1.114,0,0,0,1.04.69A.992.992,0,0,0,571,293l8.519-7.92A1.924,1.924,0,0,0,579.525,282.225Z" transform="translate(-566.984 -271.548)" fill="#b31b1b"/><path d="M584.32,293.912l-8.525-10.275,0,0L573.53,280.9l-1.389,1.254a2.063,2.063,0,0,0,0,2.965l10.812,10.419a.925.925,0,0,0,.742.282,1.039,1.039,0,0,0,.953-.667A1.261,1.261,0,0,0,584.32,293.912Z" transform="translate(-566.984 -271.548)" fill="#bdb9b4"/></g></svg>

api/core/tools/provider/builtin/arxiv/arxiv.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.arxiv.tools.arxiv_search import ArxivSearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class ArxivProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        try:
            ArxivSearchTool().fork_tool_runtime(
                meta={
                    "credentials": credentials,
                }
            ).invoke(
                user_id='',
                tool_parameters={
                    "query": "John Doe",
                },
            )
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))

api/core/tools/provider/builtin/arxiv/arxiv.yaml (new file, 10 lines)
@@ -0,0 +1,10 @@
identity:
  author: Yash Parmar
  name: arxiv
  label:
    en_US: ArXiv
    zh_Hans: ArXiv
  description:
    en_US: Access to a vast repository of scientific papers and articles in various fields of research.
    zh_Hans: 访问各个研究领域大量科学论文和文章的存储库。
  icon: icon.svg

api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from typing import Any

from langchain.utilities import ArxivAPIWrapper
from pydantic import BaseModel, Field

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class ArxivSearchInput(BaseModel):
    query: str = Field(..., description="Search query.")


class ArxivSearchTool(BuiltinTool):
    """
    A tool for searching articles on Arxiv.
    """
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        Invokes the Arxiv search tool with the given user ID and tool parameters.

        Args:
            user_id (str): The ID of the user invoking the tool.
            tool_parameters (dict[str, Any]): The parameters for the tool, including the 'query' parameter.

        Returns:
            ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages.
        """
        query = tool_parameters.get('query', '')

        if not query:
            return self.create_text_message('Please input query')

        arxiv = ArxivAPIWrapper()

        response = arxiv.run(query)

        return self.create_text_message(self.summary(user_id=user_id, content=response))

@@ -0,0 +1,23 @@
identity:
  name: arxiv_search
  author: Yash Parmar
  label:
    en_US: Arxiv Search
    zh_Hans: Arxiv 搜索
description:
  human:
    en_US: A tool for searching scientific papers and articles from the Arxiv repository. Input can be an Arxiv ID or an author's name.
    zh_Hans: 一个用于从Arxiv存储库搜索科学论文和文章的工具。 输入可以是Arxiv ID或作者姓名。
  llm: A tool for searching scientific papers and articles from the Arxiv repository. Input can be an Arxiv ID or an author's name.
parameters:
  - name: query
    type: string
    required: true
    label:
      en_US: Query string
      zh_Hans: 查询字符串
    human_description:
      en_US: The Arxiv ID or author's name used for searching.
      zh_Hans: 用于搜索的Arxiv ID或作者姓名。
    llm_description: The Arxiv ID or author's name used for searching.
    form: llm
@@ -16,7 +16,8 @@ class BingProvider(BuiltinToolProviderController):
                user_id='',
                tool_parameters={
                    "query": "test",
                    "result_type": "link"
                    "result_type": "link",
                    "enable_webpages": True,
                },
            )
        except Exception as e:
@@ -54,4 +54,4 @@ class GaodeRepositoriesTool(BuiltinTool):
            s.close()
            return self.create_text_message(f'No weather information for {city} was found.')
        except Exception as e:
            return self.create_text_message("Github API Key and Api Version is invalid. {}".format(e))
            return self.create_text_message("Gaode API Key and Api Version is invalid. {}".format(e))
api/core/tools/provider/builtin/pubmed/_assets/icon.svg (new file, 708 B)
@@ -0,0 +1 @@
<svg height="512" viewBox="0 0 448 512" width="448" xmlns="http://www.w3.org/2000/svg"><path d="m48 32c-26.5 0-48 21.5-48 48v352c0 26.5 21.5 48 48 48h352c26.5 0 48-21.5 48-48v-352c0-26.5-21.5-48-48-48zm69.56445 64s49.09165 11.12539 46.59571 94.78125c0 0 41.47034-117.171493 204.5664 1.64844 0 42.78788-.31445 172.24246-.31445 223.57031-176.89733-149.87989-207.38477-22.06836-207.38477-22.06836 0-79.8558-81.753902-70.33984-81.753902-70.33984v-212.65039s18.755175 1.4021 38.291012 11.11132zm86.14649 98.2832-24.00196 141.34961h36.5625l11.81446-81.3789h.37304l32.44727 81.3789h14.63281l33.93946-81.3789h.37304l10.31446 81.3789h36.7832l-21.40234-141.34961h-36.5625l-30.38868 75.54102-28.69531-75.54102z"/></svg>

api/core/tools/provider/builtin/pubmed/pubmed.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.pubmed.tools.pubmed_search import PubMedSearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class PubMedProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        try:
            PubMedSearchTool().fork_tool_runtime(
                meta={
                    "credentials": credentials,
                }
            ).invoke(
                user_id='',
                tool_parameters={
                    "query": "John Doe",
                },
            )
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))

api/core/tools/provider/builtin/pubmed/pubmed.yaml (new file, 10 lines)
@@ -0,0 +1,10 @@
identity:
  author: Pink Banana
  name: pubmed
  label:
    en_US: PubMed
    zh_Hans: PubMed
  description:
    en_US: A search engine for biomedical literature.
    zh_Hans: 一款生物医学文献搜索引擎。
  icon: icon.svg

@@ -0,0 +1,40 @@
from typing import Any

from langchain.tools import PubmedQueryRun
from pydantic import BaseModel, Field

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class PubMedInput(BaseModel):
    query: str = Field(..., description="Search query.")


class PubMedSearchTool(BuiltinTool):
    """
    Tool for performing a search using PubMed search engine.
    """

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        Invoke the PubMed search tool.

        Args:
            user_id (str): The ID of the user invoking the tool.
            tool_parameters (dict[str, Any]): The parameters for the tool invocation.

        Returns:
            ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation.
        """
        query = tool_parameters.get('query', '')

        if not query:
            return self.create_text_message('Please input query')

        tool = PubmedQueryRun(args_schema=PubMedInput)

        result = tool.run(query)

        return self.create_text_message(self.summary(user_id=user_id, content=result))

@@ -0,0 +1,23 @@
identity:
  name: pubmed_search
  author: Pink Banana
  label:
    en_US: PubMed Search
    zh_Hans: PubMed 搜索
description:
  human:
    en_US: PubMed® comprises more than 35 million citations for biomedical literature from MEDLINE, life science journals, and online books. Citations may include links to full text content from PubMed Central and publisher web sites.
    zh_Hans: PubMed® 包含来自 MEDLINE、生命科学期刊和在线书籍的超过 3500 万篇生物医学文献引用。引用可能包括来自 PubMed Central 和出版商网站的全文内容链接。
  llm: Perform searches on PubMed and get results.
parameters:
  - name: query
    type: string
    required: true
    label:
      en_US: Query string
      zh_Hans: 查询语句
    human_description:
      en_US: The search query.
      zh_Hans: 搜索查询语句。
    llm_description: Key words for searching
    form: llm
@@ -70,7 +70,7 @@ class StableDiffusionTool(BuiltinTool):
        if not base_url:
            return self.create_text_message('Please input base_url')

        if 'model' in tool_parameters:
        if 'model' in tool_parameters and tool_parameters['model']:
            self.runtime.credentials['model'] = tool_parameters['model']

        model = self.runtime.credentials.get('model', None)

api/core/tools/provider/builtin/wecom/_assets/icon.png (new binary file, 257 KiB; content not shown)
@@ -0,0 +1,46 @@
from typing import Any, Union

import httpx

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class WecomRepositoriesTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        content = tool_parameters.get('content', '')
        if not content:
            return self.create_text_message('Invalid parameter content')

        hook_key = tool_parameters.get('hook_key', '')
        if not hook_key:
            return self.create_text_message('Invalid parameter hook_key')

        msgtype = 'text'
        api_url = 'https://qyapi.weixin.qq.com/cgi-bin/webhook/send'
        headers = {
            'Content-Type': 'application/json',
        }
        params = {
            'key': hook_key,
        }
        payload = {
            "msgtype": msgtype,
            "text": {
                "content": content,
            }
        }

        try:
            res = httpx.post(api_url, headers=headers, params=params, json=payload)
            if res.is_success:
                return self.create_text_message("Text message sent successfully")
            else:
                return self.create_text_message(
                    f"Failed to send the text message, status code: {res.status_code}, response: {res.text}")
        except Exception as e:
            return self.create_text_message("Failed to send message to group chat bot. {}".format(e))

@@ -0,0 +1,40 @@
identity:
  name: wecom_group_bot
  author: Bowen Liang
  label:
    en_US: Send Group Message
    zh_Hans: 发送群消息
    pt_BR: Send Group Message
  icon: icon.svg
description:
  human:
    en_US: Sending a group message on Wecom via the webhook of group bot
    zh_Hans: 通过企业微信的群机器人webhook发送群消息
    pt_BR: Sending a group message on Wecom via the webhook of group bot
  llm: A tool for sending messages to a chat group on Wecom(企业微信) .
parameters:
  - name: hook_key
    type: string
    required: true
    label:
      en_US: Wecom Group bot webhook key
      zh_Hans: 群机器人webhook的key
      pt_BR: Wecom Group bot webhook key
    human_description:
      en_US: Wecom Group bot webhook key
      zh_Hans: 群机器人webhook的key
      pt_BR: Wecom Group bot webhook key
    form: form
  - name: content
    type: string
    required: true
    label:
      en_US: content
      zh_Hans: 消息内容
      pt_BR: content
    human_description:
      en_US: Content to sent to the group.
      zh_Hans: 群消息文本
      pt_BR: Content to sent to the group.
    llm_description: Content of the message
    form: llm

api/core/tools/provider/builtin/wecom/wecom.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from core.tools.provider.builtin.wecom.tools.wecom_group_bot import WecomRepositoriesTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class WecomProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        WecomRepositoriesTool()
        pass

api/core/tools/provider/builtin/wecom/wecom.yaml (new file, 13 lines)
@@ -0,0 +1,13 @@
identity:
  author: Bowen Liang
  name: wecom
  label:
    en_US: Wecom
    zh_Hans: 企业微信
    pt_BR: Wecom
  description:
    en_US: Wecom group bot
    zh_Hans: 企业微信群机器人
    pt_BR: Wecom group bot
  icon: icon.png
credentials_for_provider:
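For orientation, the webhook call that WecomRepositoriesTool wraps boils down to a single POST to the group bot endpoint shown in the code above. A minimal standalone sketch, assuming a valid group bot webhook key (the key value below is a placeholder):

import httpx

# Placeholder key; a real Wecom group bot webhook key is required.
res = httpx.post(
    'https://qyapi.weixin.qq.com/cgi-bin/webhook/send',
    params={'key': 'your-webhook-key'},
    headers={'Content-Type': 'application/json'},
    json={'msgtype': 'text', 'text': {'content': 'hello from dify'}},
)
print(res.status_code, res.text)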
@@ -1,6 +1,7 @@
import json
from json import dumps
from typing import Any, Union
from urllib.parse import urlencode

import httpx
import requests
@@ -62,6 +63,17 @@ class ApiTool(Tool):

        if 'api_key_value' not in credentials:
            raise ToolProviderCredentialValidationError('Missing api_key_value')
        elif not isinstance(credentials['api_key_value'], str):
            raise ToolProviderCredentialValidationError('api_key_value must be a string')

        if 'api_key_header_prefix' in credentials:
            api_key_header_prefix = credentials['api_key_header_prefix']
            if api_key_header_prefix == 'basic' and credentials['api_key_value']:
                credentials['api_key_value'] = f'Basic {credentials["api_key_value"]}'
            elif api_key_header_prefix == 'bearer' and credentials['api_key_value']:
                credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}'
            elif api_key_header_prefix == 'custom':
                pass

        headers[api_key_header] = credentials['api_key_value']
@@ -173,21 +185,7 @@ class ApiTool(Tool):
        for name, property in properties.items():
            if name in parameters:
                # convert type
                try:
                    value = parameters[name]
                    if property['type'] == 'integer':
                        value = int(value)
                    elif property['type'] == 'number':
                        # check if it is a float
                        if '.' in value:
                            value = float(value)
                        else:
                            value = int(value)
                    elif property['type'] == 'boolean':
                        value = bool(value)
                    body[name] = value
                except ValueError as e:
                    body[name] = parameters[name]
                body[name] = self._convert_body_property_type(property, parameters[name])
            elif name in required:
                raise ToolProviderCredentialValidationError(
                    f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
@@ -206,6 +204,8 @@ class ApiTool(Tool):
        if 'Content-Type' in headers:
            if headers['Content-Type'] == 'application/json':
                body = dumps(body)
            elif headers['Content-Type'] == 'application/x-www-form-urlencoded':
                body = urlencode(body)
            else:
                body = body
@@ -217,10 +217,6 @@ class ApiTool(Tool):
        elif method == 'put':
            response = ssrf_proxy.put(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True)
        elif method == 'delete':
            """
            request body data is unsupported for DELETE method in standard http protocol
            however, OpenAPI 3.0 supports request body data for DELETE method, so we support it here by using requests
            """
            response = ssrf_proxy.delete(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, allow_redirects=True)
        elif method == 'patch':
            response = ssrf_proxy.patch(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True)
@@ -232,6 +228,66 @@ class ApiTool(Tool):
            raise ValueError(f'Invalid http method {method}')

        return response

    def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10) -> Any:
        if max_recursive <= 0:
            raise Exception("Max recursion depth reached")
        for option in any_of or []:
            try:
                if 'type' in option:
                    # Attempt to convert the value based on the type.
                    if option['type'] == 'integer' or option['type'] == 'int':
                        return int(value)
                    elif option['type'] == 'number':
                        if '.' in str(value):
                            return float(value)
                        else:
                            return int(value)
                    elif option['type'] == 'string':
                        return str(value)
                    elif option['type'] == 'boolean':
                        if str(value).lower() in ['true', '1']:
                            return True
                        elif str(value).lower() in ['false', '0']:
                            return False
                        else:
                            continue  # Not a boolean, try next option
                    elif option['type'] == 'null' and not value:
                        return None
                    else:
                        continue  # Unsupported type, try next option
                elif 'anyOf' in option and isinstance(option['anyOf'], list):
                    # Recursive call to handle nested anyOf
                    return self._convert_body_property_any_of(property, value, option['anyOf'], max_recursive - 1)
            except ValueError:
                continue  # Conversion failed, try next option
        # If no option succeeded, you might want to return the value as is or raise an error
        return value  # or raise ValueError(f"Cannot convert value '{value}' to any specified type in anyOf")

    def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
        try:
            if 'type' in property:
                if property['type'] == 'integer' or property['type'] == 'int':
                    return int(value)
                elif property['type'] == 'number':
                    # check if it is a float
                    if '.' in value:
                        return float(value)
                    else:
                        return int(value)
                elif property['type'] == 'string':
                    return str(value)
                elif property['type'] == 'boolean':
                    return bool(value)
                elif property['type'] == 'null':
                    if value is None:
                        return None
                else:
                    raise ValueError(f"Invalid type {property['type']} for property {property}")
            elif 'anyOf' in property and isinstance(property['anyOf'], list):
                return self._convert_body_property_any_of(property, value, property['anyOf'])
        except ValueError as e:
            return value

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
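To make the conversion order concrete: _convert_body_property_type tries the declared type first and delegates anyOf schemas to _convert_body_property_any_of, which walks the options and returns the first successful conversion. A short sketch, assuming tool is an already-constructed ApiTool instance and using a hypothetical schema fragment:

# Hypothetical OpenAPI property fragment, not from a real schema.
prop = {'anyOf': [{'type': 'integer'}, {'type': 'string'}]}
tool._convert_body_property_type(prop, '42')     # int('42') succeeds -> 42
tool._convert_body_property_type(prop, 'forty')  # int() raises, falls through -> 'forty'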
@@ -4,7 +4,7 @@ from langchain.tools import BaseTool

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
from core.tools.tool.tool import Tool
@@ -15,12 +15,12 @@ class DatasetRetrieverTool(Tool):

    @staticmethod
    def get_dataset_tools(tenant_id: str,
                          dataset_ids: list[str],
                          retrieve_config: DatasetRetrieveConfigEntity,
                          return_resource: bool,
                          invoke_from: InvokeFrom,
                          hit_callback: DatasetIndexToolCallbackHandler
                          ) -> list['DatasetRetrieverTool']:
        """
        get dataset tool
        """
@@ -46,7 +46,7 @@ class DatasetRetrieverTool(Tool):
        )
        # restore retrieve strategy
        retrieve_config.retrieve_strategy = original_retriever_mode

        # convert langchain tools to Tools
        tools = []
        for langchain_tool in langchain_tools:
@@ -60,7 +60,7 @@ class DatasetRetrieverTool(Tool):
                llm=langchain_tool.description),
                runtime=DatasetRetrieverTool.Runtime()
            )

            tools.append(tool)

        return tools
@@ -68,13 +68,13 @@ class DatasetRetrieverTool(Tool):
    def get_runtime_parameters(self) -> list[ToolParameter]:
        return [
            ToolParameter(name='query',
                          label=I18nObject(en_US='', zh_Hans=''),
                          human_description=I18nObject(en_US='', zh_Hans=''),
                          type=ToolParameter.ToolParameterType.STRING,
                          form=ToolParameter.ToolParameterForm.LLM,
                          llm_description='Query for the dataset to be used to retrieve the dataset.',
                          required=True,
                          default=''),
        ]

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
@@ -84,7 +84,7 @@ class DatasetRetrieverTool(Tool):
        query = tool_parameters.get('query', None)
        if not query:
            return self.create_text_message(text='please input query')

        # invoke dataset retriever tool
        result = self.langchain_tool._run(query=query)
@@ -94,4 +94,4 @@ class DatasetRetrieverTool(Tool):
        """
        validate the credentials for dataset retriever tool
        """
        pass
@@ -146,7 +146,8 @@ class ApiBasedToolSchemaParser:
                bundles.append(ApiBasedToolBundle(
                    server_url=server_url + interface['path'],
                    method=interface['method'],
                    summary=interface['operation']['summary'] if 'summary' in interface['operation'] else None,
                    summary=interface['operation']['description'] if 'description' in interface['operation'] else
                        interface['operation']['summary'] if 'summary' in interface['operation'] else None,
                    operation_id=interface['operation']['operationId'],
                    parameters=parameters,
                    author='',
@@ -249,12 +250,10 @@ class ApiBasedToolSchemaParser:
            if 'operationId' not in operation:
                raise ToolApiSchemaError(f'No operationId found in operation {method} {path}.')

            if 'summary' not in operation or len(operation['summary']) == 0:
                warning['missing_summary'] = f'No summary found in operation {method} {path}.'
            if ('summary' not in operation or len(operation['summary']) == 0) and \
                ('description' not in operation or len(operation['description']) == 0):
                warning['missing_summary'] = f'No summary or description found in operation {method} {path}.'

            if 'description' not in operation or len(operation['description']) == 0:
                warning['missing_description'] = f'No description found in operation {method} {path}.'

            openapi['paths'][path][method] = {
                'operationId': operation['operationId'],
                'summary': operation.get('summary', ''),
@@ -7,23 +7,14 @@ import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from typing import Any

import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex

from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document

FULL_TEMPLATE = """
TITLE: {title}
@@ -36,106 +27,6 @@ TEXT:
"""


class WebReaderToolInput(BaseModel):
    url: str = Field(..., description="URL of the website to read")
    summary: bool = Field(
        default=False,
        description="When the user's question requires extracting the summarizing content of the webpage, "
                    "set it to true."
    )
    cursor: int = Field(
        default=0,
        description="Start reading from this character."
                    "Use when the first response was truncated"
                    "and you want to continue reading the page."
                    "The value cannot exceed 24000.",
    )


class WebReaderTool(BaseTool):
    """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""

    name: str = "web_reader"
    args_schema: type[BaseModel] = WebReaderToolInput
    description: str = "use this to read a website. " \
                       "If you can answer the question based on the information provided, " \
                       "there is no need to use."
    page_contents: str = None
    url: str = None
    max_chunk_length: int = 4000
    summary_chunk_tokens: int = 4000
    summary_chunk_overlap: int = 0
    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
    continue_reading: bool = True
    model_config: ModelConfigEntity
    model_parameters: dict[str, Any]

    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
        try:
            if not self.page_contents or self.url != url:
                page_contents = get_url(url)
                self.page_contents = page_contents
                self.url = url
            else:
                page_contents = self.page_contents
        except Exception as e:
            return f'Read this website failed, caused by: {str(e)}.'

        if summary:
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=self.summary_chunk_tokens,
                chunk_overlap=self.summary_chunk_overlap,
                separators=self.summary_separators
            )

            texts = character_splitter.split_text(page_contents)
            docs = [Document(page_content=t) for t in texts]

            if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
                return "No content found."

            # only use first 5 docs
            if len(docs) > 5:
                docs = docs[:5]

            chain = self.get_summary_chain()
            try:
                page_contents = chain.run(docs)
            except Exception as e:
                return f'Read this website failed, caused by: {str(e)}.'
        else:
            page_contents = page_result(page_contents, cursor, self.max_chunk_length)

        if self.continue_reading and len(page_contents) >= self.max_chunk_length:
            page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
                             f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
                             f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."

        return page_contents

    async def _arun(self, url: str) -> str:
        raise NotImplementedError

    def get_summary_chain(self) -> RefineDocumentsChain:
        initial_chain = LLMChain(
            model_config=self.model_config,
            prompt=refine_prompts.PROMPT,
            parameters=self.model_parameters
        )
        refine_chain = LLMChain(
            model_config=self.model_config,
            prompt=refine_prompts.REFINE_PROMPT,
            parameters=self.model_parameters
        )
        return RefineDocumentsChain(
            initial_llm_chain=initial_chain,
            refine_llm_chain=refine_chain,
            document_variable_name="text",
            initial_response_name="existing_answer",
            callbacks=self.callbacks
        )


def page_result(text: str, cursor: int, max_length: int) -> str:
    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
    return text[cursor: cursor + max_length]
@@ -116,6 +116,10 @@ class Dataset(db.Model):
        }
        return self.retrieval_model if self.retrieval_model else default_retrieval_model

    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        normalized_dataset_id = dataset_id.replace("-", "_")
        return f'Vector_index_{normalized_dataset_id}_Node'


class DatasetProcessRule(db.Model):
    __tablename__ = 'dataset_process_rules'
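The new helper only normalizes the dataset id into the existing vector collection naming scheme; a minimal sketch, using an arbitrary example id:

Dataset.gen_collection_name_by_id('0b3ad214-5f3e-4812-a9f6-6a8b1e0f3d21')
# -> 'Vector_index_0b3ad214_5f3e_4812_a9f6_6a8b1e0f3d21_Node'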
@@ -35,7 +35,7 @@ docx2txt==0.8
pypdfium2==4.16.0
resend~=0.7.0
pyjwt~=2.8.0
anthropic~=0.7.7
anthropic~=0.17.0
newspaper3k==0.2.8
google-api-python-client==2.90.0
wikipedia==1.4.0
@@ -52,7 +52,7 @@ safetensors==0.3.2
zhipuai==1.0.7
werkzeug~=3.0.1
pymilvus==2.3.0
qdrant-client==1.6.4
qdrant-client==1.7.3
cohere~=4.44
pyyaml~=6.0.1
numpy~=1.25.2
@@ -66,4 +66,5 @@ yfinance~=0.2.35
pydub~=0.25.1
gmpy2~=2.1.5
numexpr~=2.9.0
duckduckgo-search==4.4.3
duckduckgo-search==4.4.3
arxiv==2.1.0
@@ -1,7 +1,7 @@
import re
import uuid

from core.agent.agent_executor import PlanningStrategy
from core.entities.agent_entities import PlanningStrategy
from core.external_data_tool.factory import ExternalDataToolFactory
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers import model_provider_factory
@@ -37,7 +37,7 @@ from services.errors.account import NoPermissionError
from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.feature_service import FeatureService
from services.feature_service import FeatureModel, FeatureService
from services.vector_service import VectorService
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
@@ -469,6 +469,9 @@ class DocumentService:
        batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
        if count > batch_upload_limit:
            raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

        DocumentService.check_documents_upload_quota(count, features)

        # if dataset is empty, update dataset data_source_type
        if not dataset.data_source_type:
            dataset.data_source_type = document_data["data_source"]["type"]
@@ -619,6 +622,12 @@ class DocumentService:

        return documents, batch

    @staticmethod
    def check_documents_upload_quota(count: int, features: FeatureModel):
        can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
        if count > can_upload_size:
            raise ValueError(f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')

    @staticmethod
    def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
                       document_language: str, data_source_info: dict, created_from: str, position: int,
@@ -763,6 +772,8 @@ class DocumentService:
        if count > batch_upload_limit:
            raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

        DocumentService.check_documents_upload_quota(count, features)

        embedding_model = None
        dataset_collection_binding_id = None
        retrieval_model = None
@@ -1244,7 +1255,7 @@ class DatasetCollectionBindingService:
        dataset_collection_binding = DatasetCollectionBinding(
            provider_name=provider_name,
            model_name=model_name,
            collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node',
            collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
            type=collection_type
        )
        db.session.add(dataset_collection_binding)
@@ -25,6 +25,7 @@ class FeatureModel(BaseModel):
    apps: LimitationModel = LimitationModel(size=0, limit=10)
    vector_space: LimitationModel = LimitationModel(size=0, limit=5)
    annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
    documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
    docs_processing: str = 'standard'
    can_replace_logo: bool = False
@@ -63,6 +64,9 @@ class FeatureService:
        features.vector_space.size = billing_info['vector_space']['size']
        features.vector_space.limit = billing_info['vector_space']['limit']

        features.documents_upload_quota.size = billing_info['documents_upload_quota']['size']
        features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit']

        features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size']
        features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit']
@@ -209,8 +209,8 @@ class ToolManageService:
        # extra info like description will be set here
        tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info)

-       if len(tool_bundles) > 10:
-           raise ValueError('the number of apis should be less than 10')
+       if len(tool_bundles) > 100:
+           raise ValueError('the number of apis should be less than 100')

        # create db provider
        db_provider = ApiToolProvider(

@@ -498,12 +498,16 @@ class ToolManageService:

    @staticmethod
    def test_api_tool_preview(
-       tenant_id: str, tool_name: str, credentials: dict, parameters: dict, schema_type: str, schema: str
+       tenant_id: str,
+       provider_name: str,
+       tool_name: str,
+       credentials: dict,
+       parameters: dict,
+       schema_type: str,
+       schema: str
    ):
        """
        test api tool before adding api tool provider

        1. parse schema into tool bundle
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f'invalid schema type {schema_type}')

@@ -518,15 +522,21 @@ class ToolManageService:
        if tool_bundle is None:
            raise ValueError(f'invalid tool name {tool_name}')

-       # create a fake db provider
-       db_provider = ApiToolProvider(
-           tenant_id='', user_id='', name='', icon='',
-           schema=schema,
-           description='',
-           schema_type_str=ApiProviderSchemaType.OPENAPI.value,
-           tools_str=serialize_base_model_array(tool_bundles),
-           credentials_str=json.dumps(credentials),
-       )
+       db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
+           ApiToolProvider.tenant_id == tenant_id,
+           ApiToolProvider.name == provider_name,
+       ).first()
+
+       if not db_provider:
+           # create a fake db provider
+           db_provider = ApiToolProvider(
+               tenant_id='', user_id='', name='', icon='',
+               schema=schema,
+               description='',
+               schema_type_str=ApiProviderSchemaType.OPENAPI.value,
+               tools_str=serialize_base_model_array(tool_bundles),
+               credentials_str=json.dumps(credentials),
+           )

        if 'auth_type' not in credentials:
            raise ValueError('auth_type is required')

@@ -539,6 +549,19 @@ class ToolManageService:
        # load tools into provider entity
        provider_controller.load_bundled_tools(tool_bundles)

+       # decrypt credentials
+       if db_provider.id:
+           tool_configuration = ToolConfiguration(
+               tenant_id=tenant_id,
+               provider_controller=provider_controller
+           )
+           decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
+           # check if the credential has changed, save the original credential
+           masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
+           for name, value in credentials.items():
+               if name in masked_credentials and value == masked_credentials[name]:
+                   credentials[name] = decrypted_credentials[name]
+
        try:
            provider_controller.validate_credentials_format(credentials)
            # get tool
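The decrypt-and-compare loop added above restores real secrets when the client echoes back masked placeholders unchanged, so editing an existing provider does not overwrite stored keys with asterisks. A self-contained sketch of that round-trip logic (function name and masking format are illustrative; `ToolConfiguration`'s actual scheme may differ):

    # hypothetical standalone sketch of the mask/restore round-trip
    def restore_masked_credentials(submitted: dict, stored_plain: dict, masked: dict) -> dict:
        # if the client sent back the masked placeholder unchanged, the user
        # did not edit that field, so keep the stored plaintext value instead
        restored = dict(submitted)
        for name, value in submitted.items():
            if name in masked and value == masked[name]:
                restored[name] = stored_plain[name]
        return restored

    # e.g. restore_masked_credentials({'api_key': 'sk-****'},
    #                                 {'api_key': 'sk-secret'},
    #                                 {'api_key': 'sk-****'})
    # -> {'api_key': 'sk-secret'}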
@@ -1,52 +1,87 @@
import os
from time import sleep
- from typing import Any, Generator, List, Literal, Union
+ from typing import Any, Literal, Union, Iterable
+
+ from anthropic.resources import Messages
+ from anthropic.types.message_delta_event import Delta

import anthropic
import pytest
from _pytest.monkeypatch import MonkeyPatch
- from anthropic import Anthropic
- from anthropic._types import NOT_GIVEN, Body, Headers, NotGiven, Query
- from anthropic.resources.completions import Completions
- from anthropic.types import Completion, completion_create_params
+ from anthropic import Anthropic, Stream
+ from anthropic.types import MessageParam, Message, MessageStreamEvent, \
+     ContentBlock, MessageStartEvent, Usage, TextDelta, MessageDeltaEvent, MessageStopEvent, ContentBlockDeltaEvent, \
+     MessageDeltaUsage

MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'


class MockAnthropicClass(object):
    @staticmethod
-   def mocked_anthropic_chat_create_sync(model: str) -> Completion:
-       return Completion(
-           completion='hello, I\'m a chatbot from anthropic',
+   def mocked_anthropic_chat_create_sync(model: str) -> Message:
+       return Message(
+           id='msg-123',
+           type='message',
+           role='assistant',
+           content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')],
            model=model,
-           stop_reason='stop_sequence'
+           stop_reason='stop_sequence',
+           usage=Usage(
+               input_tokens=1,
+               output_tokens=1
+           )
        )

    @staticmethod
-   def mocked_anthropic_chat_create_stream(model: str) -> Generator[Completion, None, None]:
+   def mocked_anthropic_chat_create_stream(model: str) -> Stream[MessageStreamEvent]:
        full_response_text = "hello, I'm a chatbot from anthropic"

-       for i in range(0, len(full_response_text) + 1):
-           sleep(0.1)
-           if i == len(full_response_text):
-               yield Completion(
-                   completion='',
-                   model=model,
-                   stop_reason='stop_sequence'
-               )
-           else:
-               yield Completion(
-                   completion=full_response_text[i],
-                   model=model,
-                   stop_reason=''
-               )
+       yield MessageStartEvent(
+           type='message_start',
+           message=Message(
+               id='msg-123',
+               content=[],
+               role='assistant',
+               model=model,
+               stop_reason=None,
+               type='message',
+               usage=Usage(
+                   input_tokens=1,
+                   output_tokens=1
+               )
+           )
+       )
+
+       index = 0
+       for i in range(0, len(full_response_text)):
+           sleep(0.1)
+           yield ContentBlockDeltaEvent(
+               type='content_block_delta',
+               delta=TextDelta(text=full_response_text[i], type='text_delta'),
+               index=index
+           )
+
+           index += 1
+
+       yield MessageDeltaEvent(
+           type='message_delta',
+           delta=Delta(
+               stop_reason='stop_sequence'
+           ),
+           usage=MessageDeltaUsage(
+               output_tokens=1
+           )
+       )
+
+       yield MessageStopEvent(type='message_stop')

-   def mocked_anthropic(self: Completions, *,
-                        max_tokens_to_sample: int,
-                        model: Union[str, Literal["claude-2.1", "claude-instant-1"]],
-                        prompt: str,
-                        stream: Literal[True],
-                        **kwargs: Any
-                        ) -> Union[Completion, Generator[Completion, None, None]]:
+   def mocked_anthropic(self: Messages, *,
+                        max_tokens: int,
+                        messages: Iterable[MessageParam],
+                        model: str,
+                        stream: Literal[True],
+                        **kwargs: Any
+                        ) -> Union[Message, Stream[MessageStreamEvent]]:
        if len(self._client.api_key) < 18:
            raise anthropic.AuthenticationError('Invalid API key')

@@ -55,12 +90,13 @@ class MockAnthropicClass(object):
        else:
            return MockAnthropicClass.mocked_anthropic_chat_create_sync(model=model)


@pytest.fixture
def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
    if MOCK:
-       monkeypatch.setattr(Completions, 'create', MockAnthropicClass.mocked_anthropic)
+       monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic)

    yield

    if MOCK:
-       monkeypatch.undo()
+       monkeypatch.undo()
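With MOCK_SWITCH=true, any test that requests the fixture above transparently gets the patched `Messages.create`. A minimal usage sketch (test name, client setup, and API key value are illustrative, not from this diff):

    # hypothetical test using the fixture; the key just needs >= 18 chars
    # to pass the mock's length check
    def test_messages_is_mocked(setup_anthropic_mock):
        client = Anthropic(api_key='mock-key-long-enough')
        message = client.messages.create(
            model='claude-instant-1.2',
            max_tokens=10,
            messages=[{'role': 'user', 'content': 'hi'}],
            stream=False,
        )
        assert message.content[0].text == "hello, I'm a chatbot from anthropic"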
@@ -15,14 +15,14 @@ def test_validate_credentials(setup_anthropic_mock):

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
-           model='claude-instant-1',
+           model='claude-instant-1.2',
            credentials={
                'anthropic_api_key': 'invalid_key'
            }
        )

    model.validate_credentials(
-       model='claude-instant-1',
+       model='claude-instant-1.2',
        credentials={
            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
        }

@@ -33,7 +33,7 @@ def test_invoke_model(setup_anthropic_mock):
    model = AnthropicLargeLanguageModel()

    response = model.invoke(
-       model='claude-instant-1',
+       model='claude-instant-1.2',
        credentials={
            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'),
            'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL')

@@ -49,7 +49,7 @@ def test_invoke_model(setup_anthropic_mock):
        model_parameters={
            'temperature': 0.0,
            'top_p': 1.0,
-           'max_tokens_to_sample': 10
+           'max_tokens': 10
        },
        stop=['How'],
        stream=False,

@@ -64,7 +64,7 @@ def test_invoke_stream_model(setup_anthropic_mock):
    model = AnthropicLargeLanguageModel()

    response = model.invoke(
-       model='claude-instant-1',
+       model='claude-instant-1.2',
        credentials={
            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
        },

@@ -78,7 +78,7 @@ def test_invoke_stream_model(setup_anthropic_mock):
        ],
        model_parameters={
            'temperature': 0.0,
-           'max_tokens_to_sample': 100
+           'max_tokens': 100
        },
        stream=True,
        user="abc-123"

@@ -97,7 +97,7 @@ def test_get_num_tokens():
    model = AnthropicLargeLanguageModel()

    num_tokens = model.get_num_tokens(
-       model='claude-instant-1',
+       model='claude-instant-1.2',
        credentials={
            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
        },
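The parameter renames above track Anthropic's move from the legacy Text Completions API to the Messages API. Roughly, the old and new invocation shapes compare like this (illustrative snippets against the anthropic SDK, not code from this diff):

    # legacy Text Completions (what the old tests exercised)
    client.completions.create(
        model='claude-instant-1',
        prompt=f'{anthropic.HUMAN_PROMPT} Hello{anthropic.AI_PROMPT}',
        max_tokens_to_sample=10,
    )

    # Messages API (what the updated tests exercise)
    client.messages.create(
        model='claude-instant-1.2',
        messages=[{'role': 'user', 'content': 'Hello'}],
        max_tokens=10,
    )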
@@ -2,7 +2,7 @@ version: '3.1'
services:
  # API service
  api:
-   image: langgenius/dify-api:0.5.7
+   image: langgenius/dify-api:0.5.8
    restart: always
    environment:
      # Startup mode, 'api' starts the API server.

@@ -135,7 +135,7 @@ services:
  # worker service
  # The Celery worker for processing the queue.
  worker:
-   image: langgenius/dify-api:0.5.7
+   image: langgenius/dify-api:0.5.8
    restart: always
    environment:
      # Startup mode, 'worker' starts the Celery worker for processing the queue.

@@ -206,7 +206,7 @@ services:

  # Frontend web application.
  web:
-   image: langgenius/dify-web:0.5.7
+   image: langgenius/dify-web:0.5.8
    restart: always
    environment:
      EDITION: SELF_HOSTED
@@ -153,18 +153,18 @@ const AgentTools: FC = () => {
            )
            : (
              <div className='hidden group-hover:flex items-center'>
-               {item.provider_type === CollectionType.builtIn && (
-                 <TooltipPlus
-                   popupContent={t('tools.setBuiltInTools.infoAndSetting')}
-                 >
-                   <div className='mr-1 p-1 rounded-md hover:bg-black/5 cursor-pointer' onClick={() => {
-                     setCurrentTool(item)
-                     setIsShowSettingTool(true)
-                   }}>
-                     <InfoCircle className='w-4 h-4 text-gray-500' />
-                   </div>
-                 </TooltipPlus>
-               )}
+               {/* {item.provider_type === CollectionType.builtIn && ( */}
+               <TooltipPlus
+                 popupContent={t('tools.setBuiltInTools.infoAndSetting')}
+               >
+                 <div className='mr-1 p-1 rounded-md hover:bg-black/5 cursor-pointer' onClick={() => {
+                   setCurrentTool(item)
+                   setIsShowSettingTool(true)
+                 }}>
+                   <InfoCircle className='w-4 h-4 text-gray-500' />
+                 </div>
+               </TooltipPlus>
+               {/* )} */}

                <div className='p-1 rounded-md hover:bg-black/5 cursor-pointer' onClick={() => {
                  const newModelConfig = produce(modelConfig, (draft) => {

@@ -209,6 +209,7 @@ const AgentTools: FC = () => {
          toolName={currentTool?.tool_name as string}
          setting={currentTool?.tool_parameters as any}
          collection={currentTool?.collection as Collection}
+         isBuiltIn={currentTool?.collection?.type === CollectionType.builtIn}
          onSave={handleToolSettingChange}
          onHide={() => setIsShowSettingTool(false)}
        />)

@@ -8,14 +8,17 @@ import Drawer from '@/app/components/base/drawer-plus'
import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form'
import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
import type { Collection, Tool } from '@/app/components/tools/types'
- import { fetchBuiltInToolList } from '@/service/tools'
+ import { fetchBuiltInToolList, fetchCustomToolList } from '@/service/tools'
import I18n from '@/context/i18n'
import Button from '@/app/components/base/button'
import Loading from '@/app/components/base/loading'
import { DiagonalDividingLine } from '@/app/components/base/icons/src/public/common'
import { getLanguage } from '@/i18n/language'
+ import AppIcon from '@/app/components/base/app-icon'

type Props = {
  collection: Collection
+ isBuiltIn?: boolean
  toolName: string
  setting?: Record<string, any>
  readonly?: boolean

@@ -25,6 +28,7 @@ type Props = {

const SettingBuiltInTool: FC<Props> = ({
  collection,
+ isBuiltIn = true,
  toolName,
  setting = {},
  readonly,

@@ -52,7 +56,7 @@ const SettingBuiltInTool: FC<Props> = ({
    (async () => {
      setIsLoading(true)
      try {
-       const list = await fetchBuiltInToolList(collection.name)
+       const list = isBuiltIn ? await fetchBuiltInToolList(collection.name) : await fetchCustomToolList(collection.name)
        setTools(list)
        const currTool = list.find(tool => tool.name === toolName)
        if (currTool) {

@@ -135,12 +139,24 @@ const SettingBuiltInTool: FC<Props> = ({
      onHide={onHide}
      title={(
        <div className='flex'>
-         <div
-           className='w-6 h-6 bg-cover bg-center rounded-md'
-           style={{
-             backgroundImage: `url(${collection.icon})`,
-           }}
-         ></div>
+         {typeof collection.icon === 'string'
+           ? (
+             <div
+               className='w-6 h-6 bg-cover bg-center rounded-md'
+               style={{
+                 backgroundImage: `url(${collection.icon})`,
+               }}
+             ></div>
+           )
+           : (
+             <AppIcon
+               className='rounded-md'
+               size='tiny'
+               icon={(collection.icon as any)?.content}
+               background={(collection.icon as any)?.background}
+             />
+           )}

          <div className='ml-2 leading-6 text-base font-semibold text-gray-900'>{currTool?.label[language]}</div>
          {(hasSetting && !readonly) && (<>
            <DiagonalDividingLine className='mx-4' />
Some files were not shown because too many files have changed in this diff.