Compare commits


51 Commits

Author SHA1 Message Date
Joel
d44d4bd6fd feat: support query date tool (#662) 2023-07-27 22:27:05 +08:00
John Wang
2adaceab82 feat: bump version to 0.3.11 (#654) 2023-07-27 22:25:32 +08:00
John Wang
d979955c8a feat: optimize current time (#661) 2023-07-27 22:15:07 +08:00
Joel
eae670ea4a feat: enchance chat user experience (#660) 2023-07-27 18:04:41 +08:00
John Wang
b5825142d1 feat: add current time tool in universal chat agent (#659) 2023-07-27 17:39:36 +08:00
Joel
741e9303d4 fix: use sharp logo replace old logo (#658) 2023-07-27 16:34:30 +08:00
John Wang
538e3fc256 fix: return message error in blocking mode (#657) 2023-07-27 16:14:45 +08:00
John Wang
ba3dc8cae0 feat: fix dataset retrieve agent llm not support error (#656) 2023-07-27 15:45:52 +08:00
zxhlyh
ae7c0380dc Feat/application api add speech to text (#655) 2023-07-27 14:53:19 +08:00
Joel
23e3413655 feat: chat in explore support agent (#647)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-07-27 13:27:34 +08:00
John Wang
4fdb37771a feat: universal chat in explore (#649)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-07-27 13:08:57 +08:00
TheFu527
94b54b7ca9 feat: replace the end user column in the web page Log & Ann. with the… (#653)
Co-authored-by: Hao Fu <hao.fu@helloklarity.com>
2023-07-27 12:48:43 +08:00
crazywoola
f9412f5fdb fix: site enable check (#645) 2023-07-26 11:11:09 +08:00
zxhlyh
1d6829f400 Feat/application config user input field collapse (#643) 2023-07-26 10:27:52 +08:00
zxhlyh
f8bae897e5 fix: switch workspace (#642) 2023-07-26 10:25:35 +08:00
Selenium39
dd1172b57e Perf: Support for password display and hiding (#636)
Co-authored-by: Selenium39 <selenium39@qq.com>
2023-07-24 14:48:00 +08:00
Rhon Joe
67d326a558 fix(web): fix svg unrecognized props (#631) 2023-07-24 10:31:56 +08:00
zxhlyh
fe747040bc downgrade next version (#626) 2023-07-21 12:27:23 +08:00
Rhon Joe
7d6c925cbc fix(web): using Tooltip unique selector key (#622) 2023-07-21 11:15:00 +08:00
Joel
f488d06b20 fix: Top P description error (#624) 2023-07-21 09:15:52 +08:00
Rhon Joe
c00a19ced3 fix(web): fix Embedded copy status when toggle options (#621) 2023-07-21 09:06:51 +08:00
John Wang
e9810a6df2 fix: azure openai embedding model name error (#612) 2023-07-20 13:52:54 +08:00
John Wang
cae15013e0 fix: azure openai deployment list was deprecated suddenly (#611) 2023-07-20 13:46:39 +08:00
Jyong
52c84da051 add clean unused dataset command (#609) 2023-07-20 11:08:28 +08:00
Jyong
026f0bfce9 Feat/clean vector dataset (#605) 2023-07-19 21:30:25 +08:00
Joel
d19181fb29 chore: minify embed js (#604) 2023-07-19 19:48:44 +08:00
Yuhao
2f9de2229f feat: embed into other site support set custom host (#580)
Co-authored-by: Joel <iamjoel007@gmail.com>
2023-07-19 19:43:07 +08:00
Rhon Joe
34f55739e0 fix(web): fix #596 copy-to-clipboard issue (#602) 2023-07-19 19:29:37 +08:00
Joel
668b059c07 fix: quick switch and click create conversation button may caused fetch conversation list error (#603) 2023-07-19 17:17:29 +08:00
zxhlyh
753e5f1500 Fix/application configuration preview style (#597) 2023-07-19 12:41:35 +08:00
zxhlyh
a6af8e5d8f Fix/new conversation in mobile phone (#593) 2023-07-18 16:57:28 +08:00
zxhlyh
3e1d5ac51b Feat/header ssr (#594) 2023-07-18 16:57:14 +08:00
John Wang
b0091452ca feat: add bash before entrypoint.sh in Dockerfile (#592) 2023-07-18 16:22:34 +08:00
John Wang
eff115267f fix: anthropic completion error in blocking mode (#591) 2023-07-18 15:12:52 +08:00
John Wang
07cde4f8fe feat: bump 0.3.10 (#589) 2023-07-18 15:04:49 +08:00
Jyong
9f28a48a92 index add to db when dataset updated (#588) 2023-07-18 15:02:33 +08:00
John Wang
0d3cd3b16a fix: azure provider select error when use custom azure provider (#587) 2023-07-18 14:34:09 +08:00
John Wang
3dc82fb044 feat: remove davinci required model from azure provider (#586) 2023-07-18 14:14:56 +08:00
crazywoola
cb6e73347e Feat/add ruby sdk (#583) 2023-07-18 10:18:58 +08:00
zxhlyh
ecd6cbaee6 Fix/use embedded chatbot with no track mode (#582) 2023-07-18 09:45:17 +08:00
KVOJJJin
d54e942264 Feat: hide password setting and invitation link in cloud version (#581) 2023-07-18 08:54:14 +08:00
Panmuse
28ba721455 Update README_CN.md (#575) 2023-07-17 11:08:26 +08:00
Panmuse
784dd7848e Update README.md (#576) 2023-07-17 11:08:03 +08:00
John Wang
e2a5f8ba1a feat: bump version to 0.3.9 (#574) 2023-07-17 09:47:23 +08:00
Joel
8e11200306 feat: frontend support claude (#573)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-07-17 00:14:32 +08:00
John Wang
7599f79a17 feat: claude api support (#572) 2023-07-17 00:14:19 +08:00
Joel
510389909c fix: change chatbot avart to dify icon (#571) 2023-07-16 16:30:55 +08:00
Jyong
2c6e00174b add document limit check (#570) 2023-07-16 13:21:56 +08:00
John Wang
24f3456990 fix: account check in runtime (#569) 2023-07-15 23:58:15 +08:00
Joel
20514ff288 fix: table too wide fix text generation ui (#566) 2023-07-14 18:15:56 +08:00
zxhlyh
381d255290 fix setting-modal provider encrypted tip style (#565) 2023-07-14 17:10:02 +08:00
337 changed files with 10403 additions and 2287 deletions

View File

@@ -19,7 +19,7 @@ def check_file_for_chinese_comments(file_path):
def main():
has_chinese = False
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py']
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']
for root, _, files in os.walk("."):
for file in files:

View File

@@ -17,9 +17,15 @@ A single API encompassing plugin capabilities, context enhancement, and more, sa
Visual data analysis, log review, and annotation for applications
Dify is compatible with Langchain, meaning we'll gradually support multiple LLMs, currently supported:
- GPT 3 (text-davinci-003)
- GPT 3.5 Turbo(ChatGPT)
- GPT-4
* **OpenAI**: GPT4, GPT3.5-turbo, GPT3.5-turbo-16k, text-davinci-003
* **Azure OpenAI**
* **Anthropic**: Claude2, Claude-instant
> We've got 1,000 free trial credits available for all cloud service users to try out the Claude model. Visit [Dify.ai](https://dify.ai) and try it now.
* **Hugging Face Hub**: Coming soon.
## Use Cloud Services

View File

@@ -17,11 +17,16 @@
- A single API covering plugin capabilities, context enhancement, and more, saving you the work of writing backend code
- Visual data analysis, log review, and annotation for applications
Dify is compatible with Langchain, meaning we'll gradually support multiple LLMs, currently supported:
Dify is compatible with Langchain, meaning we'll gradually support multiple LLMs. Model providers currently supported:
- GPT 3 (text-davinci-003)
- GPT 3.5 Turbo(ChatGPT)
- GPT-4
* **OpenAI**: GPT4, GPT3.5-turbo, GPT3.5-turbo-16k, text-davinci-003
* **Azure OpenAI Service**
* **Anthropic**: Claude2, Claude-instant
> We provide all users who sign up for the cloud version with 1,000 free Claude model message calls; log in at [dify.ai](https://cloud.dify.ai) to use them.
* **Hugging Face Hub** (coming soon)
## Use Cloud Services

View File

@@ -27,4 +27,4 @@ RUN chmod +x /entrypoint.sh
ARG COMMIT_SHA
ENV COMMIT_SHA ${COMMIT_SHA}
ENTRYPOINT ["/entrypoint.sh"]
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]

View File

@@ -2,6 +2,8 @@
import os
from datetime import datetime
from werkzeug.exceptions import Forbidden
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
monkey.patch_all()
@@ -20,14 +22,14 @@ from extensions.ext_database import db
from extensions.ext_login import login_manager
# DO NOT REMOVE BELOW
from models import model, account, dataset, web, task, source
from models import model, account, dataset, web, task, source, tool
from events import event_handlers
# DO NOT REMOVE ABOVE
import core
from config import Config, CloudEditionConfig
from commands import register_commands
from models.account import TenantAccountJoin
from models.account import TenantAccountJoin, AccountStatus
from models.model import Account, EndUser, App
import warnings
@@ -101,6 +103,9 @@ def load_user(user_id):
account = db.session.query(Account).filter(Account.id == account_id).first()
if account:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
raise Forbidden('Account is banned or closed.')
workspace_id = session.get('workspace_id')
if workspace_id:
tenant_account_join = db.session.query(TenantAccountJoin).filter(

View File

@@ -2,6 +2,7 @@ import datetime
import logging
import random
import string
import time
import click
from flask import current_app
@@ -13,12 +14,13 @@ from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant
from models.dataset import Dataset
from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
from models.model import Account
import secrets
import base64
from models.provider import Provider
from models.provider import Provider, ProviderName
from services.provider_service import ProviderService
@click.command('reset-password', help='Reset the account password.')
@@ -171,7 +173,7 @@ def recreate_all_dataset_indexes():
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality')\
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
@@ -187,15 +189,103 @@ def recreate_all_dataset_indexes():
else:
click.echo('passed.')
except Exception as e:
click.echo(click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
click.echo(
click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
@click.command('clean-unused-dataset-indexes', help='Clean unused dataset indexes.')
def clean_unused_dataset_indexes():
click.echo(click.style('Start clean unused dataset indexes.', fg='green'))
clean_days = int(current_app.config.get('CLEAN_DAY_SETTING'))
start_at = time.perf_counter()
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.created_at < thirty_days_ago) \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
dataset_query = db.session.query(DatasetQuery).filter(
DatasetQuery.created_at > thirty_days_ago,
DatasetQuery.dataset_id == dataset.id
).all()
if not dataset_query or len(dataset_query) == 0:
documents = db.session.query(Document).filter(
Document.dataset_id == dataset.id,
Document.indexing_status == 'completed',
Document.enabled == True,
Document.archived == False,
Document.updated_at > thirty_days_ago
).all()
if not documents or len(documents) == 0:
try:
# remove index
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
vector_index.delete()
kw_index.delete()
# update document
update_params = {
Document.enabled: False
}
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
db.session.commit()
click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id),
fg='green'))
except Exception as e:
click.echo(
click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
end_at = time.perf_counter()
click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green'))
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
def sync_anthropic_hosted_providers():
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
count = 0
page = 1
while True:
try:
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for tenant in tenants:
try:
click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)
count += 1
except Exception as e:
click.echo(click.style(
'Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
app.cli.add_command(generate_invitation_codes)
app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(recreate_all_dataset_indexes)
app.cli.add_command(sync_anthropic_hosted_providers)
app.cli.add_command(clean_unused_dataset_indexes)
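The two new maintenance commands are registered on the Flask CLI alongside the existing ones. Below is a hedged sketch of invoking them programmatically; the `create_app` factory name and import path are assumptions, and in a deployment they would normally be run as `flask <command>` instead.

from app import create_app  # assumed application factory

app = create_app()
runner = app.test_cli_runner()

# Drop vector/keyword indexes for datasets idle longer than CLEAN_DAY_SETTING days
result = runner.invoke(args=['clean-unused-dataset-indexes'])
print(result.output)

# Grant every tenant the hosted Anthropic quota (ANTHROPIC_HOSTED_QUOTA_LIMIT)
result = runner.invoke(args=['sync-anthropic-hosted-providers'])
print(result.output)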

View File

@@ -50,7 +50,11 @@ DEFAULTS = {
'PDF_PREVIEW': 'True',
'LOG_LEVEL': 'INFO',
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
'DEFAULT_LLM_PROVIDER': 'openai'
'DEFAULT_LLM_PROVIDER': 'openai',
'OPENAI_HOSTED_QUOTA_LIMIT': 200,
'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
'TENANT_DOCUMENT_COUNT': 100,
'CLEAN_DAY_SETTING': 30
}
@@ -86,7 +90,7 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.8"
self.CURRENT_VERSION = "0.3.11"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -191,6 +195,10 @@ class Config:
# hosted provider credentials
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
# By default it is False
# You could disable it for compatibility with certain OpenAPI providers
@@ -207,6 +215,9 @@ class Config:
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
class CloudEditionConfig(Config):
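The new quota and cleanup settings follow the existing pattern: an entry in DEFAULTS plus a get_env lookup on Config. A minimal sketch of that pattern, assuming get_env falls back to DEFAULTS when the environment variable is unset (as the surrounding code suggests):

import os

DEFAULTS = {
    'OPENAI_HOSTED_QUOTA_LIMIT': 200,
    'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
    'TENANT_DOCUMENT_COUNT': 100,
    'CLEAN_DAY_SETTING': 30,
}

def get_env(key):
    # Environment variable wins; otherwise use the built-in default.
    return os.environ.get(key, DEFAULTS.get(key))

print(get_env('CLEAN_DAY_SETTING'))   # 30 unless overridden
os.environ['CLEAN_DAY_SETTING'] = '7'
print(get_env('CLEAN_DAY_SETTING'))   # '7'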

View File

@@ -18,7 +18,10 @@ from .auth import login, oauth, data_source_oauth, activate
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
# Import workspace controllers
from .workspace import workspace, members, providers, account
from .workspace import workspace, members, model_providers, account, tool_providers
# Import explore controllers
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
# Import universal chat controllers
from .universal_chat import chat, conversation, message, parameter, audio

View File

@@ -24,6 +24,7 @@ model_config_fields = {
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
'pre_prompt': fields.String,
@@ -96,7 +97,8 @@ class AppListApi(Resource):
args = parser.parse_args()
app_models = db.paginate(
db.select(App).where(App.tenant_id == current_user.current_tenant_id).order_by(App.created_at.desc()),
db.select(App).where(App.tenant_id == current_user.current_tenant_id,
App.is_universal == False).order_by(App.created_at.desc()),
page=args['page'],
per_page=args['limit'],
error_out=False)
@@ -147,6 +149,7 @@ class AppListApi(Resource):
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']),
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']),
pre_prompt=model_configuration['pre_prompt'],
@@ -438,6 +441,7 @@ class AppCopy(Resource):
suggested_questions_after_answer=app_config.suggested_questions_after_answer,
speech_to_text=app_config.speech_to_text,
more_like_this=app_config.more_like_this,
sensitive_word_avoidance=app_config.sensitive_word_avoidance,
model=app_config.model,
user_input_form=app_config.user_input_form,
pre_prompt=app_config.pre_prompt,

View File

@@ -50,8 +50,8 @@ class ChatMessageAudioApi(Resource):
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:

View File

@@ -95,6 +95,7 @@ class CompletionConversationApi(Resource):
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String(attribute='end_user.session_id'),
'from_account_id': fields.String,
'read_at': TimestampField,
'created_at': TimestampField,
@@ -135,6 +136,8 @@ class CompletionConversationApi(Resource):
query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'completion')
query = query.options(joinedload(Conversation.end_user))
if args['keyword']:
query = query.join(
Message, Message.conversation_id == Conversation.id
@@ -160,7 +163,7 @@ class CompletionConversationApi(Resource):
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0)
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
@@ -246,6 +249,7 @@ class ChatConversationApi(Resource):
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String(attribute='end_user.session_id'),
'from_account_id': fields.String,
'summary': fields.String(attribute='summary_or_query'),
'read_at': TimestampField,
@@ -288,6 +292,8 @@ class ChatConversationApi(Resource):
query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'chat')
query = query.options(joinedload(Conversation.end_user))
if args['keyword']:
query = query.join(
Message, Message.conversation_id == Conversation.id
@@ -316,7 +322,7 @@ class ChatConversationApi(Resource):
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0)
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
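One small but easy-to-miss change in this file: the end of the date filter is moved from second 0 to second 59, so the selected end minute becomes inclusive. An illustrative check:

from datetime import datetime

end_datetime = datetime.strptime('2023-07-27 18:30', '%Y-%m-%d %H:%M')
print(end_datetime.replace(second=0))   # 2023-07-27 18:30:00 - misses rows from 18:30:01-18:30:59
print(end_datetime.replace(second=59))  # 2023-07-27 18:30:59 - covers the whole final minute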

View File

@@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException):
class ProviderQuotaExceededError(BaseHTTPException):
error_code = 'provider_quota_exceeded'
description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
"Please go to Settings -> Model Provider to complete your own provider credentials."
code = 400
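The many `ProviderNotInitializeError(ex.description)` changes elsewhere in this comparison rely on these error classes accepting an optional description that overrides the class-level default. A hedged sketch of that mechanism using werkzeug's HTTPException directly (BaseHTTPException in this codebase is assumed to behave the same way):

from werkzeug.exceptions import HTTPException

class ProviderNotInitializeError(HTTPException):
    code = 400
    description = "Provider token not initialized."  # default message

# Passing ex.description at raise time replaces the default, so API clients see
# the concrete reason reported by the underlying provider error:
err = ProviderNotInitializeError("Anthropic API key is not configured.")
print(err.description)                           # Anthropic API key is not configured.
print(ProviderNotInitializeError().description)  # Provider token not initialized.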

View File

@@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource):
account.current_tenant_id,
args['prompt_template']
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -58,8 +58,8 @@ class RuleGenerateApi(Resource):
args['audiences'],
args['hoping_to_solve']
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -43,6 +43,7 @@ class ModelConfigResource(Resource):
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']),
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']),
pre_prompt=model_configuration['pre_prompt'],

View File

@@ -3,7 +3,6 @@ from flask import request
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
import services
from controllers.console import api
from controllers.console.datasets.error import DatasetNameDuplicateError

View File

@@ -279,8 +279,8 @@ class DatasetDocumentListApi(Resource):
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -324,8 +324,8 @@ class DatasetInitApi(Resource):
document_data=args,
account=current_user
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -95,8 +95,8 @@ class HitTestingApi(Resource):
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -47,8 +47,8 @@ class ChatAudioApi(InstalledAppResource):
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:

View File

@@ -65,7 +65,10 @@ class ConversationApi(InstalledAppResource):
raise NotChatAppError()
conversation_id = str(c_id)
ConversationService.delete(app_model, conversation_id, current_user)
try:
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}, 204

View File

@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -4,6 +4,10 @@ from flask_restful import marshal_with, fields
from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource
from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import InstalledApp
class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""
@@ -27,16 +31,17 @@ class AppParameterApi(InstalledAppResource):
}
@marshal_with(parameters_fields)
def get(self, installed_app):
def get(self, installed_app: InstalledApp):
"""Retrieve app parameters."""
app_model = installed_app.app
app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1')
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@@ -0,0 +1,66 @@
# -*- coding:utf-8 -*-
import logging
from flask import request
from werkzeug.exceptions import InternalServerError
import services
from controllers.console import api
from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError, \
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
from models.model import AppModelConfig
class UniversalChatAudioApi(UniversalChatResource):
def post(self, universal_app):
app_model = universal_app
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
raise AppUnavailableError()
file = request.files['file']
try:
response = AudioService.transcript(
tenant_id=app_model.tenant_id,
file=file,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
api.add_resource(UniversalChatAudioApi, '/universal-chat/audio-to-text')

View File

@@ -0,0 +1,142 @@
import json
import logging
from typing import Generator, Union
from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.constant import llm_constant
from core.conversation_message_task import PubHandler
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
from libs.helper import uuid_value
from services.completion_service import CompletionService
class UniversalChatApi(UniversalChatResource):
def post(self, universal_app):
app_model = universal_app
parser = reqparse.RequestParser()
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('model', type=str, required=True, location='json')
parser.add_argument('tools', type=list, required=True, location='json')
args = parser.parse_args()
app_model_config = app_model.app_model_config
# update app model config
args['model_config'] = app_model_config.to_dict()
args['model_config']['model']['name'] = args['model']
if not llm_constant.models[args['model']]:
raise ValueError("Model not exists.")
args['model_config']['model']['provider'] = llm_constant.models[args['model']]
args['model_config']['agent_mode']['tools'] = args['tools']
if not args['model_config']['agent_mode']['tools']:
args['model_config']['agent_mode']['tools'] = [
{
"current_datetime": {
"enabled": True
}
}
]
else:
args['model_config']['agent_mode']['tools'].append({
"current_datetime": {
"enabled": True
}
})
args['inputs'] = {}
del args['model']
del args['tools']
try:
response = CompletionService.completion(
app_model=app_model,
user=current_user,
args=args,
from_source='console',
streaming=True,
is_model_config_override=True,
)
return compact_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class UniversalChatStopApi(UniversalChatResource):
def post(self, universal_app, task_id):
PubHandler.stop(current_user, task_id)
return {'result': 'success'}, 200
def compact_response(response: Union[dict | Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
try:
for chunk in response:
yield chunk
except services.errors.conversation.ConversationNotExistsError:
yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
except services.errors.conversation.ConversationCompletedError:
yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
logging.exception("internal server error.")
yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')
api.add_resource(UniversalChatApi, '/universal-chat/messages')
api.add_resource(UniversalChatStopApi, '/universal-chat/messages/<string:task_id>/stop')

View File

@@ -0,0 +1,118 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from flask_restful import fields, reqparse, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource
from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField,
'model_config': fields.Raw,
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class UniversalChatConversationListApi(UniversalChatResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, universal_app):
app_model = universal_app
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
args = parser.parse_args()
pinned = None
if 'pinned' in args and args['pinned'] is not None:
pinned = True if args['pinned'] == 'true' else False
try:
return WebConversationService.pagination_by_last_id(
app_model=app_model,
user=current_user,
last_id=args['last_id'],
limit=args['limit'],
pinned=pinned
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
class UniversalChatConversationApi(UniversalChatResource):
def delete(self, universal_app, c_id):
app_model = universal_app
conversation_id = str(c_id)
try:
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}, 204
class UniversalChatConversationRenameApi(UniversalChatResource):
@marshal_with(conversation_fields)
def post(self, universal_app, c_id):
app_model = universal_app
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
args = parser.parse_args()
try:
return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
class UniversalChatConversationPinApi(UniversalChatResource):
def patch(self, universal_app, c_id):
app_model = universal_app
conversation_id = str(c_id)
try:
WebConversationService.pin(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}
class UniversalChatConversationUnPinApi(UniversalChatResource):
def patch(self, universal_app, c_id):
app_model = universal_app
conversation_id = str(c_id)
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}
api.add_resource(UniversalChatConversationRenameApi, '/universal-chat/conversations/<uuid:c_id>/name')
api.add_resource(UniversalChatConversationListApi, '/universal-chat/conversations')
api.add_resource(UniversalChatConversationApi, '/universal-chat/conversations/<uuid:c_id>')
api.add_resource(UniversalChatConversationPinApi, '/universal-chat/conversations/<uuid:c_id>/pin')
api.add_resource(UniversalChatConversationUnPinApi, '/universal-chat/conversations/<uuid:c_id>/unpin')

View File

@@ -0,0 +1,127 @@
# -*- coding:utf-8 -*-
import logging
from flask_login import current_user
from flask_restful import reqparse, fields, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound, InternalServerError
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
class UniversalChatMessageListApi(UniversalChatResource):
feedback_fields = {
'rating': fields.String
}
agent_thought_fields = {
'id': fields.String,
'chain_id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'thought': fields.String,
'tool': fields.String,
'tool_input': fields.String,
'created_at': TimestampField
}
message_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'created_at': TimestampField,
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
}
message_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_fields))
}
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, universal_app):
app_model = universal_app
parser = reqparse.RequestParser()
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args()
try:
return MessageService.pagination_by_first_id(app_model, current_user,
args['conversation_id'], args['first_id'], args['limit'])
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.message.FirstMessageNotExistsError:
raise NotFound("First Message Not Exists.")
class UniversalChatMessageFeedbackApi(UniversalChatResource):
def post(self, universal_app, message_id):
app_model = universal_app
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
args = parser.parse_args()
try:
MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {'result': 'success'}
class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
def get(self, universal_app, message_id):
app_model = universal_app
message_id = str(message_id)
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=current_user,
message_id=message_id
)
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
return {'data': questions}
api.add_resource(UniversalChatMessageListApi, '/universal-chat/messages')
api.add_resource(UniversalChatMessageFeedbackApi, '/universal-chat/messages/<uuid:message_id>/feedbacks')
api.add_resource(UniversalChatMessageSuggestedQuestionApi, '/universal-chat/messages/<uuid:message_id>/suggested-questions')

View File

@@ -0,0 +1,36 @@
# -*- coding:utf-8 -*-
from flask_restful import marshal_with, fields
from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import App
class UniversalChatParameterApi(UniversalChatResource):
"""Resource for app variables."""
parameters_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
}
@marshal_with(parameters_fields)
def get(self, universal_app: App):
"""Retrieve app parameters."""
app_model = universal_app
app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
}
api.add_resource(UniversalChatParameterApi, '/universal-chat/parameters')

View File

@@ -0,0 +1,84 @@
import json
from functools import wraps
from flask_login import login_required, current_user
from flask_restful import Resource
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from models.model import App, AppModelConfig
def universal_chat_app_required(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
# get universal chat app
universal_app = db.session.query(App).filter(
App.tenant_id == current_user.current_tenant_id,
App.is_universal == True
).first()
if universal_app is None:
# create universal app if not exists
universal_app = App(
tenant_id=current_user.current_tenant_id,
name='Universal Chat',
mode='chat',
is_universal=True,
icon='',
icon_background='',
api_rpm=0,
api_rph=0,
enable_site=False,
enable_api=False,
status='normal'
)
db.session.add(universal_app)
db.session.flush()
app_model_config = AppModelConfig(
provider="",
model_id="",
configs={},
opening_statement='',
suggested_questions=json.dumps([]),
suggested_questions_after_answer=json.dumps({'enabled': True}),
speech_to_text=json.dumps({'enabled': True}),
more_like_this=None,
sensitive_word_avoidance=None,
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-16k",
"completion_params": {
"max_tokens": 800,
"temperature": 0.8,
"top_p": 1,
"presence_penalty": 0,
"frequency_penalty": 0
}
}),
user_input_form=json.dumps([]),
pre_prompt='',
agent_mode=json.dumps({"enabled": True, "strategy": "function_call", "tools": []}),
)
app_model_config.app_id = universal_app.id
db.session.add(app_model_config)
db.session.flush()
universal_app.app_model_config_id = app_model_config.id
db.session.commit()
return view(universal_app, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
class UniversalChatResource(Resource):
# must be reversed if there are multiple decorators
method_decorators = [universal_chat_app_required, account_initialization_required, login_required, setup_required]

View File

@@ -3,6 +3,7 @@ import base64
import json
import logging
from flask import current_app
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, abort
from werkzeug.exceptions import Forbidden
@@ -34,7 +35,7 @@ class ProviderListApi(Resource):
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
"""
ProviderService.init_supported_provider(current_user.current_tenant, "cloud")
ProviderService.init_supported_provider(current_user.current_tenant)
providers = Provider.query.filter_by(tenant_id=tenant_id).all()
provider_list = [
@@ -50,7 +51,8 @@ class ProviderListApi(Resource):
'quota_used': p.quota_used
} if p.provider_type == ProviderType.SYSTEM.value else {}),
'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
ProviderName(p.provider_name))
ProviderName(p.provider_name), only_custom=True)
if p.provider_type == ProviderType.CUSTOM.value else None
}
for p in providers
]
@@ -121,9 +123,10 @@ class ProviderTokenApi(Resource):
is_valid=token_is_valid)
db.session.add(provider_model)
if provider_model.is_valid:
if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
other_providers = db.session.query(Provider).filter(
Provider.tenant_id == tenant.id,
Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
Provider.provider_name != provider,
Provider.provider_type == ProviderType.CUSTOM.value
).all()
@@ -133,7 +136,7 @@ class ProviderTokenApi(Resource):
db.session.commit()
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
@@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource):
args = parser.parse_args()
# todo: remove this when the provider is supported
if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
if provider in [ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
@@ -203,7 +206,19 @@ class ProviderSystemApi(Resource):
provider_model.is_valid = args['is_enabled']
db.session.commit()
elif not provider_model:
ProviderService.create_system_provider(tenant, provider, args['is_enabled'])
if provider == ProviderName.OPENAI.value:
quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
elif provider == ProviderName.ANTHROPIC.value:
quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
else:
quota_limit = 0
ProviderService.create_system_provider(
tenant,
provider,
quota_limit,
args['is_enabled']
)
else:
abort(403)

View File

@@ -0,0 +1,136 @@
import json
from flask_login import login_required, current_user
from flask_restful import Resource, abort, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.provider.tool_provider_service import ToolProviderService
from extensions.ext_database import db
from models.tool import ToolProvider, ToolProviderName
class ToolProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
tool_credential_dict = {}
for tool_name in ToolProviderName:
tool_credential_dict[tool_name.value] = {
'tool_name': tool_name.value,
'is_enabled': False,
'credentials': None
}
tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all()
for p in tool_providers:
if p.is_enabled:
tool_credential_dict[p.tool_name] = {
'tool_name': p.tool_name,
'is_enabled': p.is_enabled,
'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True)
}
return list(tool_credential_dict.values())
class ToolProviderCredentialsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
if provider not in [p.value for p in ToolProviderName]:
abort(404)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden(f'User {current_user.id} is not authorized to update provider token, '
f'current_role is {current_user.current_tenant.current_role}')
parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
tool_provider_service = ToolProviderService(tenant_id, provider)
try:
tool_provider_service.credentials_validate(args['credentials'])
except ToolValidateFailedError as ex:
raise ValueError(str(ex))
encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials']))
tenant = current_user.current_tenant
tool_provider_model = db.session.query(ToolProvider).filter(
ToolProvider.tenant_id == tenant.id,
ToolProvider.tool_name == provider,
).first()
# Only allow updating token for CUSTOM provider type
if tool_provider_model:
tool_provider_model.encrypted_credentials = encrypted_credentials
tool_provider_model.is_enabled = True
else:
tool_provider_model = ToolProvider(
tenant_id=tenant.id,
tool_name=provider,
encrypted_credentials=encrypted_credentials,
is_enabled=True
)
db.session.add(tool_provider_model)
db.session.commit()
return {'result': 'success'}, 201
class ToolProviderCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
if provider not in [p.value for p in ToolProviderName]:
abort(404)
parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
result = True
error = None
tenant_id = current_user.current_tenant_id
tool_provider_service = ToolProviderService(tenant_id, provider)
try:
tool_provider_service.credentials_validate(args['credentials'])
except ToolValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers/<provider>/credentials')
api.add_resource(ToolProviderCredentialsValidateApi,
'/workspaces/current/tool-providers/<provider>/credentials-validate')
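
The three endpoints registered above can be exercised directly over HTTP. A minimal client sketch, assuming a console API served under `BASE_URL` and a bearer token in `AUTH_HEADERS` (both placeholders, not part of the diff):

```python
# Minimal client-side sketch of the new tool-provider endpoints; BASE_URL and
# AUTH_HEADERS are assumptions for illustration, not values from the diff.
import requests

BASE_URL = "http://localhost:5001/console/api"             # assumed host/prefix
AUTH_HEADERS = {"Authorization": "Bearer <access-token>"}  # assumed auth scheme


def list_tool_providers():
    """GET /workspaces/current/tool-providers -> credential state per tool."""
    resp = requests.get(f"{BASE_URL}/workspaces/current/tool-providers", headers=AUTH_HEADERS)
    resp.raise_for_status()
    return resp.json()


def validate_and_save_credentials(provider: str, credentials: dict) -> int:
    """Validate credentials first, then persist them if validation succeeds."""
    validate = requests.post(
        f"{BASE_URL}/workspaces/current/tool-providers/{provider}/credentials-validate",
        json={"credentials": credentials},
        headers=AUTH_HEADERS,
    )
    validate.raise_for_status()
    if validate.json().get("result") != "success":
        raise ValueError(validate.json().get("error", "validation failed"))

    save = requests.post(
        f"{BASE_URL}/workspaces/current/tool-providers/{provider}/credentials",
        json={"credentials": credentials},
        headers=AUTH_HEADERS,
    )
    save.raise_for_status()
    return save.status_code  # 201 on success, per the handler above
```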

View File

@@ -4,6 +4,10 @@ from flask_restful import fields, marshal_with
from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import App
class AppParameterApi(AppApiResource):
"""Resource for app variables."""
@@ -28,15 +32,16 @@ class AppParameterApi(AppApiResource):
}
@marshal_with(parameters_fields)
def get(self, app_model, end_user):
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}
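
The parameter change above only exposes speech-to-text when the workspace's default provider for `whisper-1` is OpenAI. A standalone restatement of that guard, using the string `'openai'` as a stand-in for `ProviderName.OPENAI.value`:

```python
# Standalone restatement of the guard added above: speech-to-text is surfaced
# only when the default provider for whisper-1 is OpenAI; otherwise the app
# reports it as disabled regardless of the stored config.
def resolve_speech_to_text(provider_name: str, speech_to_text_config: dict) -> dict:
    if provider_name == "openai":          # stand-in for ProviderName.OPENAI.value
        return speech_to_text_config
    return {"enabled": False}


assert resolve_speech_to_text("openai", {"enabled": True}) == {"enabled": True}
assert resolve_speech_to_text("azure_openai", {"enabled": True}) == {"enabled": False}
```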

View File

@@ -43,8 +43,8 @@ class AudioApi(AppApiResource):
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
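
This hunk is the first of many below that stop discarding the upstream error text: the handler now passes `ex.description` into `ProviderNotInitializeError`, so the client sees why the provider is not initialized. A minimal sketch of the pattern with simplified stand-in exception classes:

```python
# Minimal sketch of the error-propagation pattern repeated across the service
# and web controllers; both exception classes here are simplified stand-ins.
class ProviderTokenNotInitError(Exception):
    def __init__(self, description: str = "provider token is not initialized"):
        super().__init__(description)
        self.description = description


class ProviderNotInitializeError(Exception):
    def __init__(self, description: str = None):
        super().__init__(description or "provider not initialized")


def call_model():
    raise ProviderTokenNotInitError("OpenAI API key is missing for this workspace.")


try:
    call_model()
except ProviderTokenNotInitError as ex:
    # Before the change the description was dropped; now it reaches the client.
    wrapped = ProviderNotInitializeError(ex.description)
    print(wrapped)
```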

View File

@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:

View File

@@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource):
dataset_process_rule=dataset.latest_process_rule,
created_from='api'
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
if doc_type and doc_metadata:
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]

View File

@@ -4,6 +4,10 @@ from flask_restful import marshal_with, fields
from controllers.web import api
from controllers.web.wraps import WebApiResource
from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import App
class AppParameterApi(WebApiResource):
"""Resource for app variables."""
@@ -27,15 +31,16 @@ class AppParameterApi(WebApiResource):
}
@marshal_with(parameters_fields)
def get(self, app_model, end_user):
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@@ -45,8 +45,8 @@ class AudioApi(WebApiResource):
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:

View File

@@ -62,7 +62,10 @@ class ConversationApi(WebApiResource):
raise NotChatAppError()
conversation_id = str(c_id)
ConversationService.delete(app_model, conversation_id, end_user)
try:
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, end_user)
return {"result": "success"}, 204

View File

@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:

View File

@@ -38,6 +38,8 @@ def decode_jwt_token():
app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
if not app_model:
raise NotFound()
if app_model.enable_site is False:
raise Unauthorized('Site is disabled.')
end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
if not end_user:
raise NotFound()

View File

@@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel):
api_key: str
class HostedAnthropicCredential(BaseModel):
api_key: str
class HostedLLMCredentials(BaseModel):
openai: Optional[HostedOpenAICredential] = None
anthropic: Optional[HostedAnthropicCredential] = None
hosted_llm_credentials = HostedLLMCredentials()
@@ -26,3 +31,6 @@ def init_app(app: Flask):
if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
if app.config.get("ANTHROPIC_API_KEY"):
hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
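
The hunk above makes both hosted credentials optional and fills them from the Flask config when present. A self-contained sketch of the same pattern, assuming pydantic and a plain dict in place of `app.config`:

```python
# Sketch of the optional hosted-credential pattern from the hunk above, with a
# plain dict in place of Flask's app.config; model names mirror the diff.
from typing import Optional
from pydantic import BaseModel


class HostedOpenAICredential(BaseModel):
    api_key: str


class HostedAnthropicCredential(BaseModel):
    api_key: str


class HostedLLMCredentials(BaseModel):
    openai: Optional[HostedOpenAICredential] = None
    anthropic: Optional[HostedAnthropicCredential] = None


def init_credentials(config: dict) -> HostedLLMCredentials:
    creds = HostedLLMCredentials()
    if config.get("OPENAI_API_KEY"):
        creds.openai = HostedOpenAICredential(api_key=config["OPENAI_API_KEY"])
    if config.get("ANTHROPIC_API_KEY"):
        creds.anthropic = HostedAnthropicCredential(api_key=config["ANTHROPIC_API_KEY"])
    return creds


print(init_credentials({"OPENAI_API_KEY": "sk-..."}))  # anthropic stays None
```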

View File

@@ -0,0 +1,35 @@
from typing import cast, List
from langchain import OpenAI
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BaseMessage
from core.constant import llm_constant
class CalcTokenMixin:
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
llm = cast(ChatOpenAI, llm)
return llm.get_num_tokens_from_messages(messages)
def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
"""
Get the remaining tokens available for the model after excluding message tokens and the completion max tokens
:param llm:
:param messages:
:return:
"""
llm = cast(ChatOpenAI, llm)
llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
completion_max_tokens = llm.max_tokens
used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
return rest_tokens
class ExceededLLMTokensLimitError(Exception):
pass
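
`get_message_rest_tokens` budgets the context window as the model limit minus the reserved completion tokens minus the tokens already used by the messages. A worked example with illustrative numbers (the real limits come from `llm_constant`):

```python
# Worked example of the budget computed by get_message_rest_tokens; the numbers
# below are illustrative, not taken from llm_constant.
def rest_tokens(context_limit: int, completion_max_tokens: int, used_tokens: int) -> int:
    return context_limit - completion_max_tokens - used_tokens


# e.g. a 4,096-token context with 512 tokens reserved for the answer and 3,700
# tokens already spent on messages leaves a negative budget, which is what
# triggers summarization downstream.
print(rest_tokens(4096, 512, 3700))  # -116
```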

View File

@@ -0,0 +1,83 @@
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
from langchain.tools import BaseTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
A multi-dataset retrieval agent driven by a router.
"""
def should_use_agent(self, query: str):
"""
Return whether the agent should be used for this query.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.tools) == 1:
tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
return super().plan(intermediate_steps, callbacks, **kwargs)
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
raise NotImplementedError()
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseSingleActionAgent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
**kwargs,
)
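
`MultiDatasetRouterAgent.plan` short-circuits when there are zero or one dataset tools and finishes after a single retrieval round; only with several datasets does it fall back to the function-calling router. A simplified restatement of that decision, with plain callables standing in for the tools and the LLM:

```python
# Simplified restatement of the routing decision in MultiDatasetRouterAgent.plan,
# with plain callables standing in for DatasetRetrieverTool and the LLM router.
from typing import Callable, List, Optional


def route(query: str,
          dataset_tools: List[Callable[[str], str]],
          llm_router: Callable[[str], str],
          last_observation: Optional[str] = None) -> str:
    if not dataset_tools:
        return ""                       # nothing to retrieve from
    if len(dataset_tools) == 1:
        return dataset_tools[0](query)  # single dataset: skip the router entirely
    if last_observation is not None:
        return last_observation         # one retrieval round only, then finish
    return llm_router(query)            # multiple datasets: let the LLM pick one


print(route("what is dify?", [lambda q: f"docs about: {q}"], lambda q: "unused"))
```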

View File

@@ -0,0 +1,112 @@
from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
_format_intermediate_steps
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseSingleActionAgent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=cls.get_system_message(),
**kwargs,
)
def should_use_agent(self, query: str):
"""
Return whether the agent should be used for this query.
:param query:
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 15
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
function_call = predicted_message.additional_kwargs.get("function_call", {})
self.llm.max_tokens = original_max_tokens
return True if function_call else False
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
@classmethod
def get_system_message(cls):
return SystemMessage(content="You are a helpful AI assistant.\n"
"The current date or current time you know is wrong.\n"
"Respond directly if appropriate.")
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: List[Tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
try:
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")

View File

@@ -0,0 +1,132 @@
from typing import cast, List
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
from pydantic import BaseModel
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
summary_handler = SummarizerMixin(llm=self.summary_llm)
self.moving_summary_buffer = summary_handler.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
llm = cast(ChatOpenAI, llm)
model, encoding = llm._get_encoding_model()
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens
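
The counting above follows the OpenAI cookbook scheme: fixed per-message and per-name token costs plus a three-token reply primer, with extra accounting for function definitions. A simplified sketch for plain chat messages only, assuming the `tiktoken` package is installed:

```python
# Simplified version of the counting scheme above for plain chat messages
# (function-definition accounting omitted); assumes tiktoken is installed.
import tiktoken


def num_tokens_for_messages(messages, model="gpt-3.5-turbo"):
    encoding = tiktoken.encoding_for_model(model)
    if model.startswith("gpt-3.5-turbo"):
        tokens_per_message, tokens_per_name = 4, -1
    elif model.startswith("gpt-4"):
        tokens_per_message, tokens_per_name = 3, 1
    else:
        raise NotImplementedError(model)

    num_tokens = 0
    for message in messages:          # each message is a dict like {"role": ..., "content": ...}
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    return num_tokens + 3             # every reply is primed with the assistant role


print(num_tokens_for_messages([{"role": "user", "content": "What time is it?"}]))
```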

View File

@@ -0,0 +1,102 @@
from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain.agents import BaseMultiActionAgent
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
_parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseMultiActionAgent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=cls.get_system_message(),
**kwargs,
)
def should_use_agent(self, query: str):
"""
Return whether the agent should be used for this query.
:param query:
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 15
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
function_call = predicted_message.additional_kwargs.get("function_call", {})
self.llm.max_tokens = original_max_tokens
return True if function_call else False
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
@classmethod
def get_system_message(cls):
# get current time
return SystemMessage(content="You are a helpful AI assistant.\n"
"The current date or current time you know is wrong.\n"
"Respond directly if appropriate.")

View File

@@ -0,0 +1,29 @@
import json
import re
from typing import Union
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser, \
logger
from langchain.schema import AgentAction, AgentFinish, OutputParserException
class StructuredChatOutputParser(LCStructuredChatOutputParser):
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
try:
action_match = re.search(r"```(.*?)\n(.*?)```?", text, re.DOTALL)
if action_match is not None:
response = json.loads(action_match.group(2).strip(), strict=False)
if isinstance(response, list):
# gpt turbo frequently ignores the directive to emit a single action
logger.warning("Got multiple action responses: %s", response)
response = response[0]
if response["action"] == "Final Answer":
return AgentFinish({"output": response["action_input"]}, text)
else:
return AgentAction(
response["action"], response.get("action_input", {}), text
)
else:
return AgentFinish({"output": text}, text)
except Exception as e:
raise OutputParserException(f"Could not parse LLM output: {text}") from e
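
The parser above extracts the fenced JSON blob from the model output, tolerates a list of actions by taking the first, and treats "Final Answer" as completion. A standalone sketch that returns plain dicts so it runs without langchain:

```python
# Standalone sketch of the blob extraction above, returning plain dicts instead
# of AgentAction/AgentFinish so it runs without langchain installed.
import json
import re


def parse_react_output(text: str) -> dict:
    match = re.search(r"```(.*?)\n(.*?)```?", text, re.DOTALL)
    if match is None:
        return {"type": "finish", "output": text}
    response = json.loads(match.group(2).strip(), strict=False)
    if isinstance(response, list):   # models sometimes emit several actions; keep the first
        response = response[0]
    if response["action"] == "Final Answer":
        return {"type": "finish", "output": response["action_input"]}
    return {"type": "action", "tool": response["action"],
            "tool_input": response.get("action_input", {})}


sample = 'Action:\n```\n{"action": "Final Answer", "action_input": "It is 2023-07-27."}\n```'
print(parse_react_output(sample))
```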

View File

@@ -0,0 +1,182 @@
import re
from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
def should_use_agent(self, query: str):
"""
Return whether the agent should be used for this query.
Determining this with an extra ReAct round-trip is costly,
so it is cheaper to simply always use the agent for reasoning.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
messages = []
if prompts:
messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
return self.output_parser.parse(full_output)
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
self.moving_summary_index = len(intermediate_steps)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
summary_handler = SummarizerMixin(llm=self.summary_llm)
if self.moving_summary_buffer and 'chat_history' in kwargs:
kwargs["chat_history"].pop()
self.moving_summary_buffer = summary_handler.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
if 'chat_history' in kwargs:
kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
formatted_tools = "\n".join(tool_strings)
tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
output_parser=output_parser,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
**kwargs,
)
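
`summarize_messages` keeps only the latest intermediate step verbatim and folds everything older (tracked by `moving_summary_index`) into a running summary. A pure-Python sketch of that window bookkeeping with a stub summarizer that just concatenates text:

```python
# Pure-Python sketch of the moving-summary window used above: older observations
# are folded into a running summary, only the latest step stays verbatim.
def summarize_steps(intermediate_steps, moving_summary, moving_summary_index):
    if len(intermediate_steps) < 2:
        raise RuntimeError("Exceeded LLM tokens limit, stopped.")
    to_summarize = intermediate_steps[moving_summary_index:-1]
    new_index = len(intermediate_steps)
    summary_input = " ".join(observation for _, observation in to_summarize)
    new_summary = (moving_summary + " " + summary_input).strip()   # stub summarizer
    return new_summary, new_index, intermediate_steps[-1]


steps = [("search", "obs 1"), ("search", "obs 2"), ("search", "obs 3")]
summary, index, latest = summarize_steps(steps, "", 0)
print(summary)   # "obs 1 obs 2"
print(latest)    # ("search", "obs 3") stays in the prompt verbatim
```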

View File

@@ -1,86 +0,0 @@
from typing import Optional
from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
from langchain.callbacks.manager import CallbackManager
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.llm.llm_builder import LLMBuilder
class AgentBuilder:
@classmethod
def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
dataset_tool_callback_handler: DatasetToolCallbackHandler,
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=agent_loop_gather_callback_handler.model_name,
temperature=0,
max_tokens=1024,
callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
)
for tool in tools:
tool.callbacks = [
agent_loop_gather_callback_handler,
dataset_tool_callback_handler,
DifyStdOutCallbackHandler()
]
prompt = cls.build_agent_prompt_template(
tools=tools,
memory=memory,
)
agent_llm_chain = LLMChain(
llm=llm,
prompt=prompt,
)
agent = cls.build_agent(agent_llm_chain=agent_llm_chain, memory=memory)
agent_callback_manager = CallbackManager(
[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
)
agent_chain = AgentExecutor.from_agent_and_tools(
tools=tools,
agent=agent,
memory=memory,
callbacks=agent_callback_manager,
max_iterations=6,
early_stopping_method="generate",
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
)
return agent_chain
@classmethod
def build_agent_prompt_template(cls, tools, memory: Optional[BaseChatMemory]):
if memory:
prompt = ConversationalAgent.create_prompt(
tools=tools,
)
else:
prompt = ZeroShotAgent.create_prompt(
tools=tools,
)
return prompt
@classmethod
def build_agent(cls, agent_llm_chain: LLMChain, memory: Optional[BaseChatMemory]):
if memory:
agent = ConversationalAgent(
llm_chain=agent_llm_chain
)
else:
agent = ZeroShotAgent(
llm_chain=agent_llm_chain
)
return agent

View File

@@ -0,0 +1,122 @@
import enum
import logging
from typing import Union, Optional
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum):
ROUTER = 'router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
MULTI_FUNCTION_CALL = 'multi_function_call'
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
llm: BaseLanguageModel
tools: list[BaseTool]
summary_llm: BaseLanguageModel
dataset_llm: BaseLanguageModel
memory: Optional[BaseChatMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6
max_execution_time: Optional[float] = None
early_stopping_method: str = "generate"
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
class AgentExecuteResult(BaseModel):
strategy: PlanningStrategy
output: Optional[str]
configuration: AgentConfiguration
class AgentExecutor:
def __init__(self, configuration: AgentConfiguration):
self.configuration = configuration
self.agent = self._init_agent()
def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
llm=self.configuration.llm,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_llm,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_llm,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_llm,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
llm=self.configuration.dataset_llm,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
verbose=True
)
else:
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
return agent
def should_use_agent(self, query: str) -> bool:
return self.agent.should_use_agent(query)
def run(self, query: str) -> AgentExecuteResult:
agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent,
tools=self.configuration.tools,
memory=self.configuration.memory,
max_iterations=self.configuration.max_iterations,
max_execution_time=self.configuration.max_execution_time,
early_stopping_method=self.configuration.early_stopping_method,
callbacks=self.configuration.callbacks
)
try:
output = agent_executor.run(query)
except Exception:
logging.exception("agent_executor run failed")
output = None
return AgentExecuteResult(
output=output,
strategy=self.configuration.strategy,
configuration=self.configuration
)
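
A hedged usage sketch of the new configuration and executor follows. The import path is assumed from the surrounding modules, and the ChatOpenAI models and empty tool list are placeholders for illustration only, not how Dify itself builds them:

```python
# Hedged usage sketch; the module path and model construction are assumptions.
from langchain.chat_models import ChatOpenAI

from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy

llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

config = AgentConfiguration(
    strategy=PlanningStrategy.FUNCTION_CALL,
    llm=llm,
    tools=[],                 # DatasetRetrieverTool / other tools would go here
    summary_llm=llm,          # reused for summarizing overlong histories
    dataset_llm=llm,          # only used by the ROUTER strategy
    max_iterations=6,
)

executor = AgentExecutor(config)
if executor.should_use_agent("What day is it today?"):
    result = executor.run("What day is it today?")
    print(result.strategy, result.output)
```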

View File

@@ -1,10 +1,12 @@
import json
import logging
import time
from typing import Any, Dict, List, Union, Optional
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
@@ -20,6 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self.conversation_message_task = conversation_message_task
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
self.current_chain = None
@property
@@ -29,6 +32,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
def clear_agent_loops(self) -> None:
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
@property
def always_verbose(self) -> bool:
@@ -61,9 +65,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
# kwargs={}
if self._current_loop and self._current_loop.status == 'llm_started':
self._current_loop.status = 'llm_end'
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
self._current_loop.completion = response.generations[0][0].text
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
if response.llm_output:
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
completion_generation = response.generations[0][0]
if isinstance(completion_generation, ChatGeneration):
completion_message = completion_generation.message
if 'function_call' in completion_message.additional_kwargs:
self._current_loop.completion \
= json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
else:
self._current_loop.completion = response.generations[0][0].text
else:
self._current_loop.completion = completion_generation.text
if response.llm_output:
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -71,6 +87,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
logging.error(error)
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
def on_tool_start(
self,
@@ -89,15 +106,29 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
) -> Any:
"""Run on agent action."""
tool = action.tool
tool_input = action.tool_input
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
tool_input = json.dumps({"query": action.tool_input}
if isinstance(action.tool_input, str) else action.tool_input)
completion = None
if isinstance(action, openai_functions_agent.base._FunctionsAgentAction) \
or isinstance(action, openai_functions_multi_agent.base._FunctionsAgentAction):
thought = action.log.strip()
completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']})
else:
action_name_position = action.log.index("Action:") if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
if self._current_loop and self._current_loop.status == 'llm_end':
self._current_loop.status = 'agent_action'
self._current_loop.thought = thought
self._current_loop.tool_name = tool
self._current_loop.tool_input = tool_input
if completion is not None:
self._current_loop.completion = completion
self._message_agent_thought = self.conversation_message_task.on_agent_start(
self.current_chain,
self._current_loop
)
def on_tool_end(
self,
@@ -120,10 +151,13 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
)
self._agent_loops.append(self._current_loop)
self._current_loop = None
self._message_agent_thought = None
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -132,6 +166,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
logging.error(error)
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
@@ -141,10 +176,18 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completed = True
self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self._current_loop.thought = '[DONE]'
self._message_agent_thought = self.conversation_message_task.on_agent_start(
self.current_chain,
self._current_loop
)
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
)
self._agent_loops.append(self._current_loop)
self._current_loop = None
self._message_agent_thought = None
elif not self._current_loop and self._agent_loops:
self._agent_loops[-1].status = 'agent_finish'
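
`on_llm_end` now tolerates a missing `llm_output` and handles `ChatGeneration` results, serializing the `function_call` when present instead of relying on the generation text. A standalone restatement of that extraction using plain dicts in place of `LLMResult`/`ChatGeneration`:

```python
# Standalone restatement of the completion extraction added above, using plain
# dicts in place of LLMResult/ChatGeneration so it runs without langchain.
import json


def extract_completion(generation: dict) -> str:
    # `generation` mimics either a chat generation ({"message": ..., "text": ...})
    # or a plain completion ({"text": ...}).
    message = generation.get("message")
    if message is not None and "function_call" in message.get("additional_kwargs", {}):
        return json.dumps({"function_call": message["additional_kwargs"]["function_call"]})
    return generation["text"]


chat_gen = {"text": "", "message": {"additional_kwargs": {
    "function_call": {"name": "current_datetime", "arguments": "{}"}}}}
plain_gen = {"text": "The current time tool returned 22:27."}
print(extract_completion(chat_gen))   # serialized function_call
print(extract_completion(plain_gen))  # plain generation text
```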

View File

@@ -1,3 +1,4 @@
import json
import logging
from typing import Any, Dict, List, Union, Optional
@@ -43,9 +44,11 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
input_str: str,
**kwargs: Any,
) -> None:
tool_name = serialized.get('name')
dataset_id = tool_name[len("dataset-"):]
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str))
# tool_name = serialized.get('name')
input_dict = json.loads(input_str.replace("'", "\""))
dataset_id = input_dict.get('dataset_id')
query = input_dict.get('query')
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
def on_tool_end(
self,

View File

@@ -10,9 +10,9 @@ class AgentLoop(BaseModel):
tool_output: str = None
prompt: str = None
prompt_tokens: int = None
prompt_tokens: int = 0
completion: str = None
completion_tokens: int = None
completion_tokens: int = 0
latency: float = None

View File

@@ -1,20 +1,18 @@
import logging
import time
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
class LLMCallbackHandler(BaseCallbackHandler):
raise_error: bool = True
def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
def __init__(self, llm: BaseLanguageModel,
conversation_message_task: ConversationMessageTask):
self.llm = llm
self.llm_message = LLMMessage()
@@ -48,7 +46,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
})
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -69,9 +67,8 @@ class LLMCallbackHandler(BaseCallbackHandler):
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(response.generations[0][0].text)
self.llm_message.completion = response.generations[0][0].text
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.conversation_message_task.save_message(self.llm_message)

View File

@@ -20,15 +20,13 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
self._current_chain_result = None
self._current_chain_message = None
self.conversation_message_task = conversation_message_task
self.agent_loop_gather_callback_handler = AgentLoopGatherCallbackHandler(
llm_constant.agent_model_name,
conversation_message_task
)
self.agent_callback = None
def clear_chain_results(self) -> None:
self._current_chain_result = None
self._current_chain_message = None
self.agent_loop_gather_callback_handler.current_chain = None
if self.agent_callback:
self.agent_callback.current_chain = None
@property
def always_verbose(self) -> bool:
@@ -58,7 +56,8 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
started_at=time.perf_counter()
)
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
if self.agent_callback:
self.agent_callback.current_chain = self._current_chain_message
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""

View File

@@ -1,32 +0,0 @@
from typing import Optional
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain
class ChainBuilder:
@classmethod
def to_tool_chain(cls, tool, **kwargs) -> ToolChain:
return ToolChain(
tool=tool,
input_key=kwargs.get('input_key', 'input'),
output_key=kwargs.get('output_key', 'tool_output'),
callbacks=[DifyStdOutCallbackHandler()]
)
@classmethod
def to_sensitive_word_avoidance_chain(cls, tool_config: dict, **kwargs) -> Optional[
SensitiveWordAvoidanceChain]:
sensitive_words = tool_config.get("words", "")
if tool_config.get("enabled", False) \
and sensitive_words:
return SensitiveWordAvoidanceChain(
sensitive_words=sensitive_words.split(","),
canned_response=tool_config.get("canned_response", ''),
output_key="sensitive_word_avoidance_output",
callbacks=[DifyStdOutCallbackHandler()],
**kwargs
)
return None

View File

@@ -1,111 +0,0 @@
"""Base classes for LLM-powered router chains."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import root_validator
from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException
from libs.json_in_md_parser import parse_and_check_json_markdown
class Route(NamedTuple):
destination: Optional[str]
next_inputs: Dict[str, Any]
class LLMRouterChain(Chain):
"""A router chain that uses an LLM chain to perform routing."""
llm_chain: LLMChain
"""LLM chain used to perform routing"""
@root_validator()
def validate_prompt(cls, values: dict) -> dict:
prompt = values["llm_chain"].prompt
if prompt.output_parser is None:
raise ValueError(
"LLMRouterChain requires base llm_chain prompt to have an output"
" parser that converts LLM text output to a dictionary with keys"
" 'destination' and 'next_inputs'. Received a prompt with no output"
" parser."
)
return values
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the LLM chain prompt expects.
:meta private:
"""
return self.llm_chain.input_keys
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
super()._validate_outputs(outputs)
if not isinstance(outputs["next_inputs"], dict):
raise ValueError
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
output = cast(
Dict[str, Any],
self.llm_chain.predict_and_parse(**inputs),
)
return output
@classmethod
def from_llm(
cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
) -> LLMRouterChain:
"""Convenience constructor."""
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)
@property
def output_keys(self) -> List[str]:
return ["destination", "next_inputs"]
def route(self, inputs: Dict[str, Any]) -> Route:
result = self(inputs)
return Route(result["destination"], result["next_inputs"])
class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
"""Parser for output of router chain int he multi-prompt chain."""
default_destination: str = "DEFAULT"
next_inputs_type: Type = str
next_inputs_inner_key: str = "input"
def parse(self, text: str) -> Dict[str, Any]:
try:
expected_keys = ["destination", "next_inputs"]
parsed = parse_and_check_json_markdown(text, expected_keys)
if not isinstance(parsed["destination"], str):
raise ValueError("Expected 'destination' to be a string.")
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
raise ValueError(
f"Expected 'next_inputs' to be {self.next_inputs_type}."
)
parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
if (
parsed["destination"].strip().lower()
== self.default_destination.lower()
):
parsed["destination"] = None
else:
parsed["destination"] = parsed["destination"].strip()
return parsed
except Exception as e:
raise OutputParserException(
f"Parsing text\n{text}\n of llm router raised following error:\n{e}"
)

View File

@@ -1,110 +0,0 @@
from typing import Optional, List, cast
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.conversation_message_task import ConversationMessageTask
from extensions.ext_database import db
from models.dataset import Dataset
class MainChainBuilder:
@classmethod
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
rest_tokens: int,
conversation_message_task: ConversationMessageTask):
first_input_key = "input"
final_output_key = "output"
chains = []
chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)
# agent mode
tool_chains, chains_output_key = cls.get_agent_chains(
tenant_id=tenant_id,
agent_mode=agent_mode,
rest_tokens=rest_tokens,
memory=memory,
conversation_message_task=conversation_message_task
)
chains += tool_chains
if chains_output_key:
final_output_key = chains_output_key
if len(chains) == 0:
return None
for chain in chains:
chain = cast(Chain, chain)
chain.callbacks.append(chain_callback_handler)
# build main chain
overall_chain = SequentialChain(
chains=chains,
input_variables=[first_input_key],
output_variables=[final_output_key],
memory=memory, # only for use the memory prompt input key
)
return overall_chain
@classmethod
def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
rest_tokens: int,
memory: Optional[BaseChatMemory],
conversation_message_task: ConversationMessageTask):
# agent mode
chains = []
if agent_mode and agent_mode.get('enabled'):
tools = agent_mode.get('tools', [])
pre_fixed_chains = []
# agent_tools = []
datasets = []
for tool in tools:
tool_type = list(tool.keys())[0]
tool_config = list(tool.values())[0]
if tool_type == 'sensitive-word-avoidance':
chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
if chain:
pre_fixed_chains.append(chain)
elif tool_type == "dataset":
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == tool_config.get("id")
).first()
if dataset:
datasets.append(dataset)
# add pre-fixed chains
chains += pre_fixed_chains
if len(datasets) > 0:
# tool to chain
multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
tenant_id=tenant_id,
datasets=datasets,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
callbacks=[DifyStdOutCallbackHandler()]
)
chains.append(multi_dataset_router_chain)
final_output_key = cls.get_chains_output_key(chains)
return chains, final_output_key
@classmethod
def get_chains_output_key(cls, chains: List[Chain]):
if len(chains) > 0:
return chains[-1].output_keys[0]
return None

View File

@@ -1,198 +0,0 @@
import math
import re
from typing import Mapping, List, Dict, Any, Optional
from langchain import PromptTemplate
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import Extra
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_index_tool import DatasetTool
from models.dataset import Dataset, DatasetProcessRule
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.
<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like, \
no any other string out of markdown code snippet:
```json
{{{{
"destination": string \\ name of the prompt to use or "DEFAULT"
"next_inputs": string \\ a potentially modified version of the original input
}}}}
```
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.
<< CANDIDATE PROMPTS >>
{destinations}
<< INPUT >>
{{input}}
<< OUTPUT >>
"""
class MultiDatasetRouterChain(Chain):
"""Use a single chain to route an input to one of multiple candidate chains."""
router_chain: LLMRouterChain
"""Chain for deciding a destination chain and the input to it."""
dataset_tools: Mapping[str, DatasetTool]
"""Map of name to candidate chains that inputs can be routed to."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the router chain prompt expects.
:meta private:
"""
return self.router_chain.input_keys
@property
def output_keys(self) -> List[str]:
return ["text"]
@classmethod
def from_datasets(
cls,
tenant_id: str,
datasets: List[Dataset],
conversation_message_task: ConversationMessageTask,
rest_tokens: int,
**kwargs: Any,
):
"""Convenience constructor for instantiating from destination prompts."""
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
temperature=0,
max_tokens=1024,
callbacks=[DifyStdOutCallbackHandler()]
)
destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
else ('useful for when you want to answer queries about the ' + d.name))
for d in datasets]
destinations_str = "\n".join(destinations)
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
destinations=destinations_str
)
router_prompt = PromptTemplate(
template=router_template,
input_variables=["input"],
output_parser=RouterOutputParser(),
)
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
dataset_tools = {}
for dataset in datasets:
# skip datasets that have no available documents or segments
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
continue
# fulfill description when it is empty
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
if k == 0:
continue
dataset_tool = DatasetTool(
name=f"dataset-{dataset.id}",
description=description,
k=k,
dataset=dataset,
callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
)
dataset_tools[str(dataset.id)] = dataset_tool
return cls(
router_chain=router_chain,
dataset_tools=dataset_tools,
**kwargs,
)
@classmethod
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
processing_rule = dataset.latest_process_rule
if not processing_rule:
return DEFAULT_K
if processing_rule.mode == "custom":
rules = processing_rule.rules_dict
if not rules:
return DEFAULT_K
segmentation = rules["segmentation"]
segment_max_tokens = segmentation["max_tokens"]
else:
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
# when rest_tokens is less than default context tokens
if rest_tokens < segment_max_tokens * DEFAULT_K:
return rest_tokens // segment_max_tokens
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
# when context_limit_tokens is less than default context tokens, use default_k
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
return DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space
return context_limit_tokens // segment_max_tokens
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if len(self.dataset_tools) == 0:
return {"text": ''}
elif len(self.dataset_tools) == 1:
return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}
route = self.router_chain.route(inputs)
destination = ''
if route.destination:
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
match = re.search(pattern, route.destination, re.IGNORECASE)
if match:
destination = match.group()
if not destination:
return {"text": ''}
elif destination in self.dataset_tools:
return {"text": self.dataset_tools[destination].run(
route.next_inputs['input']
)}
else:
raise ValueError(
f"Received invalid destination chain name '{destination}'"
)
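
The retrieval depth above is derived from the remaining token budget. A minimal sketch of that arithmetic, using assumed numbers (segment_max_tokens=500, rest_tokens=6000; neither value comes from this diff):

```python
import math

DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3

def dynamic_k(rest_tokens: int, segment_max_tokens: int) -> int:
    # Not even room for the default two segments: shrink k to what fits.
    if rest_tokens < segment_max_tokens * DEFAULT_K:
        return rest_tokens // segment_max_tokens
    # Otherwise reserve ~30% of the remaining budget for retrieved context.
    context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
    if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
        return DEFAULT_K
    return context_limit_tokens // segment_max_tokens

print(dynamic_k(6000, 500))  # 1800 context tokens / 500 per segment -> k = 3
print(dynamic_k(800, 500))   # under the default budget -> k = 1
```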

View File

@@ -1,51 +0,0 @@
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.tools import BaseTool
class ToolChain(Chain):
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
tool: BaseTool
@property
def _chain_type(self) -> str:
return "tool_chain"
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
input = inputs[self.input_key]
output = self.tool.run(input, self.verbose)
return {self.output_key: output}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
input = inputs[self.input_key]
output = await self.tool.arun(input, self.verbose)
return {self.output_key: output}

View File

@@ -1,4 +1,5 @@
import logging
import re
from typing import Optional, List, Union, Tuple
from langchain.base_language import BaseLanguageModel
@@ -8,30 +9,31 @@ from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.fake import FakeLLM
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
ReadOnlyConversationTokenDBStringBufferSharedMemory
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
class Completion:
@classmethod
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
user: Account, conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
"""
errors: ProviderTokenNotInitError
"""
@@ -69,18 +71,33 @@ class Completion:
streaming=streaming
)
# build main chain include agent
main_chain = MainChainBuilder.to_langchain_components(
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
# init orchestrator rule parser
orchestrator_rule_parser = OrchestratorRuleParser(
tenant_id=app.tenant_id,
agent_mode=app_model_config.agent_mode_dict,
rest_tokens=rest_tokens_for_context_and_memory,
memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
conversation_message_task=conversation_message_task
app_model_config=app_model_config
)
chain_output = ''
if main_chain:
chain_output = main_chain.run(query)
# parse sensitive_word_avoidance_chain
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
if sensitive_word_avoidance_chain:
query = sensitive_word_avoidance_chain.run(query)
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
conversation_message_task=conversation_message_task,
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
chain_callback=chain_callback
)
# run agent executor
agent_execute_result = None
if agent_executor:
should_use_agent = agent_executor.should_use_agent(query)
if should_use_agent:
agent_execute_result = agent_executor.run(query)
# run the final llm
try:
@@ -90,7 +107,7 @@ class Completion:
app_model_config=app_model_config,
query=query,
inputs=inputs,
chain_output=chain_output,
agent_execute_result=agent_execute_result,
conversation_message_task=conversation_message_task,
memory=memory,
streaming=streaming
@@ -105,9 +122,20 @@ class Completion:
@classmethod
def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
chain_output: str,
agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
# When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy != PlanningStrategy.ROUTER:
final_llm = FakeLLM(response=agent_execute_result.output,
origin_llm=agent_execute_result.configuration.llm,
streaming=streaming)
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
response = final_llm.generate([[HumanMessage(content=query)]])
return response
final_llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id,
model=app_model_config.model_dict,
@@ -118,10 +146,11 @@ class Completion:
prompt, stop_words = cls.get_main_llm_prompt(
mode=mode,
llm=final_llm,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
chain_output=chain_output,
agent_execute_result=agent_execute_result,
memory=memory
)
@@ -129,6 +158,7 @@ class Completion:
cls.recale_llm_max_tokens(
final_llm=final_llm,
model=app_model_config.model_dict,
prompt=prompt,
mode=mode
)
@@ -138,41 +168,31 @@ class Completion:
return response
@classmethod
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
chain_output: Optional[str],
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
pre_prompt: str, query: str, inputs: dict,
agent_execute_result: Optional[AgentExecuteResult],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
# disable template string in query
# query_params = JinjaPromptTemplate.from_template(template=query).input_variables
# if query_params:
# for query_param in query_params:
# if query_param not in inputs:
# inputs[query_param] = '{{' + query_param + '}}'
if mode == 'completion':
prompt_template = JinjaPromptTemplate.from_template(
template=("""Use the following CONTEXT as your learned knowledge:
[CONTEXT]
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
[END CONTEXT]
</context>
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if chain_output else "")
""" if agent_execute_result else "")
+ (pre_prompt + "\n" if pre_prompt else "")
+ "{{query}}\n"
)
if chain_output:
inputs['context'] = chain_output
# context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
# if context_params:
# for context_param in context_params:
# if context_param not in inputs:
# inputs[context_param] = '{{' + context_param + '}}'
if agent_execute_result:
inputs['context'] = agent_execute_result.output
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
prompt_content = prompt_template.format(
@@ -202,12 +222,13 @@ And answer according to the language of the user's question.
if pre_prompt_inputs:
human_inputs.update(pre_prompt_inputs)
if chain_output:
human_inputs['context'] = chain_output
human_message_prompt += """Use the following CONTEXT as your learned knowledge.
[CONTEXT]
if agent_execute_result:
human_inputs['context'] = agent_execute_result.output
human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
[END CONTEXT]
</context>
When answer to user:
- If you don't know, just say that you don't know.
@@ -219,7 +240,7 @@ And answer according to the language of the user's question.
if pre_prompt:
human_message_prompt += pre_prompt
query_prompt = "\nHuman: {{query}}\nAI: "
query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
if memory:
# append chat histories
@@ -228,20 +249,17 @@ And answer according to the language of the user's question.
inputs=human_inputs
)
curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
- memory.llm.max_tokens - curr_message_tokens
curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
model_name = model['name']
max_tokens = model.get("completion_params").get('max_tokens')
rest_tokens = llm_constant.max_context_token_length[model_name] \
- max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
# disable template string in query
# histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
# if histories_params:
# for histories_param in histories_params:
# if histories_param not in human_inputs:
# human_inputs[histories_param] = '{{' + histories_param + '}}'
human_message_prompt += "\n\n" + histories
human_message_prompt += "\n\n" if human_message_prompt else ""
human_message_prompt += "Here is the chat histories between human and assistant, " \
"inside <histories></histories> XML tags.\n\n<histories>\n"
human_message_prompt += histories + "\n</histories>"
human_message_prompt += query_prompt
@@ -253,10 +271,13 @@ And answer according to the language of the user's question.
messages.append(human_message)
return messages, ['\nHuman:']
for message in messages:
message.content = re.sub(r'<\|.*?\|>', '', message.content)
return messages, ['\nHuman:', '</histories>']
@classmethod
def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
def get_llm_callbacks(cls, llm: BaseLanguageModel,
streaming: bool,
conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
@@ -267,8 +288,7 @@ And answer according to the language of the user's question.
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
max_token_limit: int) -> \
str:
max_token_limit: int) -> str:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory_key = memory.memory_variables[0]
@@ -307,17 +327,19 @@ And answer according to the language of the user's question.
model=app_model_config.model_dict
)
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
max_tokens = llm.max_tokens
model_name = app_model_config.model_dict.get("name")
model_limited_tokens = llm_constant.max_context_token_length[model_name]
max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
# get prompt without memory and context
prompt, _ = cls.get_main_llm_prompt(
mode=mode,
llm=llm,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
chain_output=None,
agent_execute_result=None,
memory=None
)
@@ -332,16 +354,17 @@ And answer according to the language of the user's question.
return rest_tokens
@classmethod
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
prompt: Union[str, List[BaseMessage]], mode: str):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
max_tokens = final_llm.max_tokens
model_name = model.get("name")
model_limited_tokens = llm_constant.max_context_token_length[model_name]
max_tokens = model.get("completion_params").get('max_tokens')
if mode == 'completion' and isinstance(final_llm, BaseLLM):
prompt_tokens = final_llm.get_num_tokens(prompt)
else:
prompt_tokens = final_llm.get_messages_tokens(prompt)
prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
if prompt_tokens + max_tokens > model_limited_tokens:
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
@@ -350,9 +373,10 @@ And answer according to the language of the user's question.
@classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool):
llm: StreamableOpenAI = LLMBuilder.to_llm(
llm = LLMBuilder.to_llm_from_model(
tenant_id=app.tenant_id,
model_name='gpt-3.5-turbo',
model=app_model_config.model_dict,
streaming=streaming
)
@@ -360,10 +384,12 @@ And answer according to the language of the user's question.
original_prompt, _ = cls.get_main_llm_prompt(
mode="completion",
llm=llm,
model=app_model_config.model_dict,
pre_prompt=pre_prompt,
query=message.query,
inputs=message.inputs,
chain_output=None,
agent_execute_result=None,
memory=None
)
@@ -390,6 +416,7 @@ And answer according to the language of the user's question.
cls.recale_llm_max_tokens(
final_llm=llm,
model=app_model_config.model_dict,
prompt=prompt,
mode='completion'
)
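
For the max_tokens recalculation above, a quick worked example with assumed numbers (a 4,096-token context window, a 3,900-token prompt and a requested max_tokens of 512; none of these figures are from the diff):

```python
model_limited_tokens = 4096   # assumed context window
prompt_tokens = 3900          # assumed prompt size
max_tokens = 512              # requested completion budget

# Shrink the completion budget so prompt + completion fits the window,
# but never below 16 tokens.
if prompt_tokens + max_tokens > model_limited_tokens:
    max_tokens = max(model_limited_tokens - prompt_tokens, 16)

print(max_tokens)  # 196
```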

View File

@@ -1,6 +1,8 @@
from _decimal import Decimal
models = {
'claude-instant-1': 'anthropic', # 100,000 tokens
'claude-2': 'anthropic', # 100,000 tokens
'gpt-4': 'openai', # 8,192 tokens
'gpt-4-32k': 'openai', # 32,768 tokens
'gpt-3.5-turbo': 'openai', # 4,096 tokens
@@ -10,10 +12,13 @@ models = {
'text-curie-001': 'openai', # 2,049 tokens
'text-babbage-001': 'openai', # 2,049 tokens
'text-ada-001': 'openai', # 2,049 tokens
'text-embedding-ada-002': 'openai' # 8191 tokens, 1536 dimensions
'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions
'whisper-1': 'openai'
}
max_context_token_length = {
'claude-instant-1': 100000,
'claude-2': 100000,
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
@@ -23,17 +28,21 @@ max_context_token_length = {
'text-curie-001': 2049,
'text-babbage-001': 2049,
'text-ada-001': 2049,
'text-embedding-ada-002': 8191
'text-embedding-ada-002': 8191,
}
models_by_mode = {
'chat': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens
],
'completion': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
@@ -52,6 +61,14 @@ models_by_mode = {
model_currency = 'USD'
model_prices = {
'claude-instant-1': {
'prompt': Decimal('0.00163'),
'completion': Decimal('0.00551'),
},
'claude-2': {
'prompt': Decimal('0.01102'),
'completion': Decimal('0.03268'),
},
'gpt-4': {
'prompt': Decimal('0.03'),
'completion': Decimal('0.06'),
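
A small illustration of how these unit prices are typically applied, assuming (as with the upstream published price lists) that they are USD per 1,000 tokens; the token counts are made up:

```python
from decimal import Decimal

# Assumed usage figures for illustration only.
prompt_tokens = Decimal(1200)
completion_tokens = Decimal(300)

prompt_price = Decimal('0.01102')      # claude-2 prompt price from the table
completion_price = Decimal('0.03268')  # claude-2 completion price from the table

total = (prompt_tokens * prompt_price + completion_tokens * completion_price) / 1000
print(total)  # 0.023028 USD
```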

View File

@@ -52,11 +52,11 @@ class ConversationMessageTask:
message=self.message,
conversation=self.conversation,
chain_pub=False, # disabled currently
agent_thought_pub=False # disabled currently
agent_thought_pub=True
)
def init(self):
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id)
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
self.model_dict['provider'] = provider_name
override_model_configs = None
@@ -69,6 +69,7 @@ class ConversationMessageTask:
"suggested_questions": self.app_model_config.suggested_questions_list,
"suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
"more_like_this": self.app_model_config.more_like_this_dict,
"sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
"user_input_form": self.app_model_config.user_input_form_list,
}
@@ -89,7 +90,7 @@ class ConversationMessageTask:
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
system_instruction = system_message.content
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
system_instruction_tokens = llm.get_messages_tokens([system_message])
system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
if not self.conversation:
self.is_new_conversation = True
@@ -185,6 +186,7 @@ class ConversationMessageTask:
if provider and provider.provider_type == ProviderType.SYSTEM.value:
db.session.query(Provider).filter(
Provider.tenant_id == self.app.tenant_id,
Provider.provider_name == provider.provider_name,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + 1})
@@ -206,7 +208,28 @@ class ConversationMessageTask:
self._pub_handler.pub_chain(message_chain)
def on_agent_end(self, message_chain: MessageChain, agent_model_name: str,
def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
message_agent_thought = MessageAgentThought(
message_id=self.message.id,
message_chain_id=message_chain.id,
position=agent_loop.position,
thought=agent_loop.thought,
tool=agent_loop.tool_name,
tool_input=agent_loop.tool_input,
message=agent_loop.prompt,
answer=agent_loop.completion,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_agent_thought)
db.session.flush()
self._pub_handler.pub_agent_thought(message_agent_thought)
return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
agent_loop: AgentLoop):
agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
@@ -221,34 +244,18 @@ class ConversationMessageTask:
agent_answer_unit_price
)
message_agent_loop = MessageAgentThought(
message_id=self.message.id,
message_chain_id=message_chain.id,
position=agent_loop.position,
thought=agent_loop.thought,
tool=agent_loop.tool_name,
tool_input=agent_loop.tool_input,
observation=agent_loop.tool_output,
tool_process_data='', # currently not support
message=agent_loop.prompt,
message_token=loop_message_tokens,
message_unit_price=agent_message_unit_price,
answer=agent_loop.completion,
answer_token=loop_answer_tokens,
answer_unit_price=agent_answer_unit_price,
latency=agent_loop.latency,
tokens=agent_loop.prompt_tokens + agent_loop.completion_tokens,
total_price=loop_total_price,
currency=llm_constant.model_currency,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_agent_loop)
message_agent_thought.observation = agent_loop.tool_output
message_agent_thought.tool_process_data = '' # currently not support
message_agent_thought.message_token = loop_message_tokens
message_agent_thought.message_unit_price = agent_message_unit_price
message_agent_thought.answer_token = loop_answer_tokens
message_agent_thought.answer_unit_price = agent_answer_unit_price
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
message_agent_thought.currency = llm_constant.model_currency
db.session.flush()
self._pub_handler.pub_agent_thought(message_agent_loop)
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
dataset_query = DatasetQuery(
dataset_id=dataset_query_obj.dataset_id,
@@ -345,16 +352,14 @@ class PubHandler:
content = {
'event': 'agent_thought',
'data': {
'id': message_agent_thought.id,
'task_id': self._task_id,
'message_id': self._message.id,
'chain_id': message_agent_thought.message_chain_id,
'agent_thought_id': message_agent_thought.id,
'position': message_agent_thought.position,
'thought': message_agent_thought.thought,
'tool': message_agent_thought.tool,
'tool_input': message_agent_thought.tool_input,
'observation': message_agent_thought.observation,
'answer': message_agent_thought.answer,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
}
@@ -387,6 +392,15 @@ class PubHandler:
def _is_stopped(self):
return redis_client.get(self._stopped_cache_key) is not None
@classmethod
def ping(cls, user: Union[Account | EndUser], task_id: str):
content = {
'event': 'ping'
}
channel = cls.generate_channel_name(user, task_id)
redis_client.publish(channel, json.dumps(content))
@classmethod
def stop(cls, user: Union[Account | EndUser], task_id: str):
stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)

View File

@@ -1,7 +1,8 @@
import tempfile
from pathlib import Path
from typing import List, Union
from typing import List, Union, Optional
import requests
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document
@@ -13,6 +14,9 @@ from core.data_loader.loader.pdf import PdfLoader
from extensions.ext_storage import storage
from models.model import UploadFile
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
class FileExtractor:
@classmethod
@@ -22,22 +26,41 @@ class FileExtractor:
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
delimiter = '\n'
if input_file.suffix == '.xlsx':
loader = ExcelLoader(file_path)
elif input_file.suffix == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif input_file.suffix in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif input_file.suffix in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif input_file.suffix == '.docx':
loader = Docx2txtLoader(file_path)
elif input_file.suffix == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)
return cls.load_from_file(file_path, return_text, upload_file)
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
response = requests.get(url, headers={
"User-Agent": USER_AGENT
})
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(url).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
with open(file_path, 'wb') as file:
file.write(response.content)
return cls.load_from_file(file_path, return_text)
@classmethod
def load_from_file(cls, file_path: str, return_text: bool = False,
upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
input_file = Path(file_path)
delimiter = '\n'
if input_file.suffix == '.xlsx':
loader = ExcelLoader(file_path)
elif input_file.suffix == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif input_file.suffix in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif input_file.suffix in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif input_file.suffix == '.docx':
loader = Docx2txtLoader(file_path)
elif input_file.suffix == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
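
A hypothetical call to the new URL loader; the URL is a placeholder and the import path is assumed from the repository layout shown above:

```python
from core.data_loader.file_extractor import FileExtractor

# Download a remote PDF and return its text content as a single string.
text = FileExtractor.load_from_url("https://example.com/whitepaper.pdf", return_text=True)
print(text[:200])
```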

View File

@@ -4,6 +4,7 @@ from typing import List
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding
@@ -49,6 +50,7 @@ class CacheEmbedding(Embeddings):
text_embeddings.extend(embedding_results)
return text_embeddings
@handle_openai_exceptions
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists

View File

@@ -23,6 +23,10 @@ class LLMGenerator:
@classmethod
def generate_conversation_name(cls, tenant_id: str, query, answer):
prompt = CONVERSATION_TITLE_PROMPT
if len(query) > 2000:
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
prompt = prompt.format(query=query)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
@@ -52,7 +56,17 @@ class LLMGenerator:
if not message.answer:
continue
message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
if len(message.query) > 2000:
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
else:
query = message.query
if len(message.answer) > 2000:
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
else:
answer = message.answer
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
context += message_qa_text

View File

@@ -17,7 +17,7 @@ class IndexBuilder:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
)

View File

@@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception):
"""
description = "Provider Token Not Init"
def __init__(self, *args, **kwargs):
self.description = args[0] if args else self.description
class QuotaExceededError(Exception):
"""

59
api/core/llm/fake.py Normal file
View File

@@ -0,0 +1,59 @@
import time
from typing import List, Optional, Any, Mapping
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration, BaseLanguageModel
class FakeLLM(SimpleChatModel):
"""Fake ChatModel for testing purposes."""
streaming: bool = False
"""Whether to stream the results or not."""
response: str
origin_llm: Optional[BaseLanguageModel] = None
@property
def _llm_type(self) -> str:
return "fake-chat-model"
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""First try to lookup in queries, else return 'foo' or 'bar'."""
return self.response
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {"response": self.response}
def get_num_tokens(self, text: str) -> int:
return self.origin_llm.get_num_tokens(text) if self.origin_llm else 0
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
if self.streaming:
for token in output_str:
if run_manager:
run_manager.on_llm_new_token(token)
time.sleep(0.01)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
llm_output = {"token_usage": {
'prompt_tokens': 0,
'completion_tokens': 0,
'total_tokens': 0,
}}
return ChatResult(generations=[generation], llm_output=llm_output)
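
A minimal sketch of how this fake model can replay a pre-computed answer through the usual chat interface; the response and question strings are made up:

```python
from langchain.schema import HumanMessage

from core.llm.fake import FakeLLM

# Wrap an already-known answer so downstream callbacks still see a ChatResult.
fake = FakeLLM(response="The dataset contains 42 documents.", streaming=False)
result = fake.generate([[HumanMessage(content="How many documents are there?")]])
print(result.generations[0][0].text)  # -> The dataset contains 42 documents.
```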

View File

@@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from models.provider import ProviderType
from models.provider import ProviderType, ProviderName
class LLMBuilder:
@@ -32,43 +33,43 @@ class LLMBuilder:
@classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
provider = cls.get_default_provider(tenant_id)
provider = cls.get_default_provider(tenant_id, model_name)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
llm_cls = None
mode = cls.get_mode_by_model(model_name)
if mode == 'chat':
if provider == 'openai':
if provider == ProviderName.OPENAI.value:
llm_cls = StreamableChatOpenAI
else:
elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureChatOpenAI
elif provider == ProviderName.ANTHROPIC.value:
llm_cls = StreamableChatAnthropic
elif mode == 'completion':
if provider == 'openai':
if provider == ProviderName.OPENAI.value:
llm_cls = StreamableOpenAI
else:
elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureOpenAI
else:
if not llm_cls:
raise ValueError(f"model name {model_name} is not supported.")
model_kwargs = {
'model_name': model_name,
'temperature': kwargs.get('temperature', 0),
'max_tokens': kwargs.get('max_tokens', 256),
'top_p': kwargs.get('top_p', 1),
'frequency_penalty': kwargs.get('frequency_penalty', 0),
'presence_penalty': kwargs.get('presence_penalty', 0),
'callbacks': kwargs.get('callbacks', None),
'streaming': kwargs.get('streaming', False),
}
model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
model_kwargs.update(model_credentials)
model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
return llm_cls(
model_name=model_name,
temperature=kwargs.get('temperature', 0),
max_tokens=kwargs.get('max_tokens', 256),
**model_extras_kwargs,
callbacks=kwargs.get('callbacks', None),
streaming=kwargs.get('streaming', False),
# request_timeout=None
**model_credentials
)
return llm_cls(**model_kwargs)
@classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
@@ -118,14 +119,30 @@ class LLMBuilder:
return provider_service.get_credentials(model_name)
@classmethod
def get_default_provider(cls, tenant_id: str) -> str:
provider = BaseProvider.get_valid_provider(tenant_id)
if not provider:
raise ProviderTokenNotInitError()
def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
provider_name = llm_constant.models[model_name]
if provider_name == 'openai':
# get the default provider (openai / azure_openai) for the tenant
openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
provider = None
if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value:
provider = openai_provider
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value:
provider = azure_openai_provider
elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value:
provider = openai_provider
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value:
provider = azure_openai_provider
if not provider:
raise ProviderTokenNotInitError(
f"No valid {provider_name} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if provider.provider_type == ProviderType.SYSTEM.value:
provider_name = 'openai'
else:
provider_name = provider.provider_name
return provider_name
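
A hedged usage sketch of the refactored builder; the tenant id is a placeholder and the keyword arguments are the ones handled in to_llm above:

```python
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.llm.llm_builder import LLMBuilder

# Resolves the default provider for claude-2 (anthropic per llm_constant.models)
# and returns a streaming chat model wired to the stdout callback.
llm = LLMBuilder.to_llm(
    tenant_id='00000000-0000-0000-0000-000000000000',  # placeholder tenant id
    model_name='claude-2',
    temperature=0,
    max_tokens=256,
    streaming=True,
    callbacks=[DifyStdOutCallbackHandler()],
)
```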

View File

@@ -1,23 +1,138 @@
from typing import Optional
import json
import logging
from typing import Optional, Union
import anthropic
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from models.provider import ProviderName
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName, ProviderType
class AnthropicProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id)
# todo
return []
return [
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
},
{
'id': 'claude-2',
'name': 'claude-2',
},
]
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id.
The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
"""
return {
'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
}
return self.get_provider_api_key(model_id=model_id)
def get_provider_name(self):
return ProviderName.ANTHROPIC
return ProviderName.ANTHROPIC
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = {
'anthropic_api_key': ''
}
if obfuscated:
if not config.get('anthropic_api_key'):
config = {
'anthropic_api_key': ''
}
config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
return config
return config
def get_encrypted_token(self, config: Union[dict | str]):
"""
Returns the encrypted token.
"""
return json.dumps({
'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
})
def get_decrypted_token(self, token: str):
"""
Returns the decrypted token.
"""
config = json.loads(token)
config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
return config
def get_token_type(self):
return dict
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
# check OpenAI / Azure OpenAI credential is valid
openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
provider = None
if openai_provider:
provider = openai_provider
elif azure_openai_provider:
provider = azure_openai_provider
if not provider:
raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
if quota_used >= quota_limit:
raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
f"please configure OpenAI or Azure OpenAI provider first.")
try:
if not isinstance(config, dict):
raise ValueError('Config must be a object.')
if 'anthropic_api_key' not in config:
raise ValueError('anthropic_api_key must be provided.')
chat_llm = ChatAnthropic(
model='claude-instant-1',
anthropic_api_key=config['anthropic_api_key'],
max_tokens_to_sample=10,
temperature=0,
default_request_timeout=60
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except anthropic.APIConnectionError as ex:
raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
f"{ex.body['error']['type']}: {ex.body['error']['message']}")
except Exception as ex:
logging.exception('Anthropic config validation failed')
raise ex
def get_hosted_credentials(self) -> Union[str | dict]:
if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}
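
A hypothetical validation call against the provider above; the tenant id and API key are placeholders, and the module path is assumed from the repository layout:

```python
from core.llm.provider.anthropic_provider import AnthropicProvider

provider = AnthropicProvider(tenant_id='00000000-0000-0000-0000-000000000000')
# Raises ValidateFailedError if no OpenAI / Azure OpenAI provider is configured,
# if the hosted quota is exhausted, or if the Anthropic key fails the "ping" call.
provider.config_validate({'anthropic_api_key': 'sk-ant-placeholder'})
```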

View File

@@ -2,6 +2,7 @@ import json
import logging
from typing import Optional, Union
import openai
import requests
from core.llm.provider.base import BaseProvider
@@ -9,32 +10,42 @@ from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
credentials = self.get_credentials(model_id) if not credentials else credentials
url = "{}/openai/deployments?api-version={}".format(
str(credentials.get('openai_api_base')),
str(credentials.get('openai_api_version'))
)
return []
headers = {
"api-key": str(credentials.get('openai_api_key')),
"content-type": "application/json; charset=utf-8"
}
response = requests.get(url, headers=headers)
if response.status_code == 200:
result = response.json()
return [{
'id': deployment['id'],
'name': '{} ({})'.format(deployment['id'], deployment['model'])
} for deployment in result['data'] if deployment['status'] == 'succeeded']
else:
if response.status_code == 401:
raise AzureAuthenticationError()
def check_embedding_model(self, credentials: Optional[dict] = None):
credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
try:
result = openai.Embedding.create(input=['test'],
engine='text-embedding-ada-002',
timeout=60,
api_key=str(credentials.get('openai_api_key')),
api_base=str(credentials.get('openai_api_base')),
api_type='azure',
api_version=str(credentials.get('openai_api_version')))["data"][0][
"embedding"]
except openai.error.AuthenticationError as e:
raise AzureAuthenticationError(str(e))
except openai.error.APIConnectionError as e:
raise AzureRequestFailedError(
'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`')
except openai.error.InvalidRequestError as e:
if e.http_status == 404:
raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
"deployment name is exists in Azure AI")
else:
raise AzureRequestFailedError('Failed to request Azure OpenAI. Status code: {}'.format(response.status_code))
raise AzureRequestFailedError(
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
except openai.error.OpenAIError as e:
raise AzureRequestFailedError(
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
if not isinstance(result, list):
raise AzureRequestFailedError('Failed to request Azure OpenAI.')
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
@@ -42,9 +53,10 @@ class AzureProvider(BaseProvider):
"""
config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure'
config['openai_api_version'] = AZURE_OPENAI_API_VERSION
if model_id == 'text-embedding-ada-002':
config['deployment'] = model_id.replace('.', '') if model_id else None
config['chunk_size'] = 1
config['chunk_size'] = 16
else:
config['deployment_name'] = model_id.replace('.', '') if model_id else None
return config
@@ -52,16 +64,16 @@ class AzureProvider(BaseProvider):
def get_provider_name(self):
return ProviderName.AZURE_OPENAI
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key()
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = {
'openai_api_type': 'azure',
'openai_api_version': '2023-03-15-preview',
'openai_api_version': AZURE_OPENAI_API_VERSION,
'openai_api_base': '',
'openai_api_key': ''
}
@@ -70,7 +82,7 @@ class AzureProvider(BaseProvider):
if not config.get('openai_api_key'):
config = {
'openai_api_type': 'azure',
'openai_api_version': '2023-03-15-preview',
'openai_api_version': AZURE_OPENAI_API_VERSION,
'openai_api_base': '',
'openai_api_key': ''
}
@@ -81,7 +93,6 @@ class AzureProvider(BaseProvider):
return config
def get_token_type(self):
# TODO: change to dict when implemented
return dict
def config_validate(self, config: Union[dict | str]):
@@ -93,34 +104,13 @@ class AzureProvider(BaseProvider):
raise ValueError('Config must be a object.')
if 'openai_api_version' not in config:
config['openai_api_version'] = '2023-03-15-preview'
config['openai_api_version'] = AZURE_OPENAI_API_VERSION
models = self.get_models(credentials=config)
if not models:
raise ValidateFailedError("Please add deployments for 'text-davinci-003', "
"'gpt-3.5-turbo', 'text-embedding-ada-002' (required) "
"and 'gpt-4', 'gpt-35-turbo-16k' (optional).")
fixed_model_ids = [
'text-davinci-003',
'gpt-35-turbo',
'text-embedding-ada-002'
]
current_model_ids = [model['id'] for model in models]
missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if
fixed_model_id not in current_model_ids]
if missing_model_ids:
raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids)))
self.check_embedding_model(credentials=config)
except ValidateFailedError as e:
raise e
except AzureAuthenticationError:
raise ValidateFailedError('Validation failed, please check your API Key.')
except (requests.ConnectionError, requests.RequestException):
raise ValidateFailedError('Validation failed, please check your API Base Endpoint.')
except AzureRequestFailedError as ex:
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
except Exception as ex:
@@ -133,7 +123,7 @@ class AzureProvider(BaseProvider):
"""
return json.dumps({
'openai_api_type': 'azure',
'openai_api_version': '2023-03-15-preview',
'openai_api_version': AZURE_OPENAI_API_VERSION,
'openai_api_base': config['openai_api_base'],
'openai_api_key': self.encrypt_token(config['openai_api_key'])
})

View File

@@ -2,7 +2,7 @@ import base64
from abc import ABC, abstractmethod
from typing import Optional, Union
from core import hosted_llm_credentials
from core.constant import llm_constant
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from extensions.ext_database import db
from libs import rsa
@@ -14,15 +14,18 @@ class BaseProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]:
def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
If the provider is not found or not valid, raises a ProviderTokenNotInitError.
"""
provider = self.get_provider(prefer_custom)
provider = self.get_provider(only_custom)
if not provider:
raise ProviderTokenNotInitError()
raise ProviderTokenNotInitError(
f"No valid {llm_constant.models[model_id]} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
@@ -38,18 +41,19 @@ class BaseProvider(ABC):
else:
return self.get_decrypted_token(provider.encrypted_config)
def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
"""
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
@classmethod
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
If both CUSTOM and System providers exist.
"""
query = db.session.query(Provider).filter(
Provider.tenant_id == tenant_id
@@ -58,39 +62,31 @@ class BaseProvider(ABC):
if provider_name:
query = query.filter(Provider.provider_name == provider_name)
providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
if only_custom:
query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
custom_provider = None
system_provider = None
providers = query.order_by(Provider.provider_type.asc()).all()
for provider in providers:
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
custom_provider = provider
return provider
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
system_provider = provider
return provider
if custom_provider:
return custom_provider
elif system_provider:
return system_provider
else:
return None
return None
def get_hosted_credentials(self) -> str:
if self.get_provider_name() != ProviderName.OPENAI:
raise ProviderTokenNotInitError()
def get_hosted_credentials(self) -> Union[str | dict]:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
raise ProviderTokenNotInitError()
return hosted_llm_credentials.openai.api_key
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key()
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = ''

View File

@@ -31,11 +31,11 @@ class LLMProviderService:
def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.provider.get_credentials(model_id)
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
return self.provider.get_provider_configs(obfuscated)
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]:
return self.provider.get_provider(prefer_custom)
def get_provider_db_record(self) -> Optional[Provider]:
return self.provider.get_provider()
def config_validate(self, config: Union[dict | str]):
"""

View File

@@ -4,6 +4,8 @@ from typing import Optional, Union
import openai
from openai.error import AuthenticationError, OpenAIError
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.moderation import Moderation
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
@@ -42,3 +44,12 @@ class OpenAIProvider(BaseProvider):
except Exception as ex:
logging.exception('OpenAI config validation failed')
raise ex
def get_hosted_credentials(self) -> Union[str | dict]:
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return hosted_llm_credentials.openai.api_key

View File

@@ -1,14 +1,20 @@
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun
from langchain.chat_models.openai import _convert_dict_to_message
from langchain.schema import BaseMessage, LLMResult, ChatResult, ChatGeneration
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict, Any, Tuple, Union
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableAzureChatOpenAI(AzureChatOpenAI):
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
max_retries: int = 1
"""Maximum number of retries to make when generating."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
@@ -46,30 +52,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in a list of messages.
Args:
messages: The messages to count the tokens of.
Returns:
The number of tokens in the messages.
"""
tokens_per_message = 5
tokens_per_request = 3
message_tokens = tokens_per_request
message_strs = ''
for message in messages:
message_strs += message.content
message_tokens += tokens_per_message
# calc once
message_tokens += self.get_num_tokens(message_strs)
return message_tokens
@handle_llm_exceptions
@handle_openai_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
@@ -79,12 +62,58 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
) -> LLMResult:
return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(messages, stop, callbacks, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
model_kwargs = {
'top_p': params.get('top_p', 1),
'frequency_penalty': params.get('frequency_penalty', 0),
'presence_penalty': params.get('presence_penalty', 0),
}
del params['top_p']
del params['frequency_penalty']
del params['presence_penalty']
params['model_kwargs'] = model_kwargs
return params
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
role = "assistant"
params["stream"] = True
function_call: Optional[dict] = None
for stream_resp in self.completion_with_retry(
messages=message_dicts, **params
):
if len(stream_resp["choices"]) > 0:
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content") or ""
inner_completion += token
_function_call = stream_resp["choices"][0]["delta"].get("function_call")
if _function_call:
if function_call is None:
function_call = _function_call
else:
function_call["arguments"] += _function_call["arguments"]
if run_manager:
run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
"content": inner_completion,
"role": role,
"function_call": function_call,
}
)
return ChatResult(generations=[ChatGeneration(message=message)])
response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)
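
A minimal, self-contained sketch of how the streaming branch above folds per-chunk deltas into a single assistant message; the chunk dicts below are illustrative stand-ins for the OpenAI stream payloads, not SDK types.

# Sketch: fold streamed chat deltas into one message dict (illustrative only).
def fold_stream_chunks(chunks: list[dict]) -> dict:
    content = ""
    role = "assistant"
    function_call = None
    for chunk in chunks:
        if not chunk["choices"]:
            continue
        delta = chunk["choices"][0]["delta"]
        role = delta.get("role", role)
        content += delta.get("content") or ""
        fc = delta.get("function_call")
        if fc:
            if function_call is None:
                function_call = dict(fc)
            else:
                function_call["arguments"] += fc["arguments"]
    return {"role": role, "content": content, "function_call": function_call}

chunks = [
    {"choices": [{"delta": {"role": "assistant", "content": "Hel"}}]},
    {"choices": [{"delta": {"content": "lo"}}]},
]
assert fold_stream_chunks(chunks)["content"] == "Hello"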

View File

@@ -1,16 +1,20 @@
from langchain.callbacks.manager import Callbacks
from langchain.llms import AzureOpenAI
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Mapping, Any
from typing import Optional, List, Dict, Mapping, Any, Union, Tuple
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableAzureOpenAI(AzureOpenAI):
openai_api_type: str = "azure"
openai_api_version: str = ""
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
max_retries: int = 1
"""Maximum number of retries to make when generating."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@@ -50,7 +54,7 @@ class StreamableAzureOpenAI(AzureOpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}}
@handle_llm_exceptions
@handle_openai_exceptions
def generate(
self,
prompts: List[str],
@@ -60,12 +64,6 @@ class StreamableAzureOpenAI(AzureOpenAI):
) -> LLMResult:
return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(prompts, stop, callbacks, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
return params

View File

@@ -0,0 +1,62 @@
from typing import List, Optional, Any, Dict
from httpx import Timeout
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
from pydantic import root_validator
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
class StreamableChatAnthropic(ChatAnthropic):
"""
Wrapper around Anthropic's large language model.
"""
default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)
@root_validator()
def prepare_params(cls, values: Dict) -> Dict:
values['model_name'] = values.get('model')
values['max_tokens'] = values.get('max_tokens_to_sample')
return values
@handle_anthropic_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
params['model'] = params.get('model_name')
del params['model_name']
params['max_tokens_to_sample'] = params.get('max_tokens')
del params['max_tokens']
del params['frequency_penalty']
del params['presence_penalty']
return params
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = f"{self.HUMAN_PROMPT} {message.content}"
elif isinstance(message, AIMessage):
message_text = f"{self.AI_PROMPT} {message.content}"
elif isinstance(message, SystemMessage):
message_text = f"<admin>{message.content}</admin>"
else:
raise ValueError(f"Got unknown type {message}")
return message_text
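
For orientation, a rough, self-contained sketch of the prompt layout produced by _convert_one_message_to_text above. The HUMAN_PROMPT and AI_PROMPT values are assumptions about the anthropic SDK constants, not taken from this diff.

# Assumed values of the anthropic SDK prompt constants (verify against the SDK).
HUMAN_PROMPT = "\n\nHuman:"
AI_PROMPT = "\n\nAssistant:"

def render(messages: list[tuple[str, str]]) -> str:
    parts = []
    for kind, content in messages:
        if kind == "system":
            parts.append(f"<admin>{content}</admin>")
        elif kind == "human":
            parts.append(f"{HUMAN_PROMPT} {content}")
        else:  # "ai"
            parts.append(f"{AI_PROMPT} {content}")
    return "".join(parts)

print(render([("system", "Be terse."), ("human", "Hi"), ("ai", "Hello.")]))
# <admin>Be terse.</admin>
#
# Human: Hi
#
# Assistant: Hello.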

View File

@@ -3,14 +3,18 @@ import os
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult
from langchain.chat_models import ChatOpenAI
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict, Any, Union, Tuple
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableChatOpenAI(ChatOpenAI):
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
max_retries: int = 1
"""Maximum number of retries to make when generating."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@@ -48,30 +52,7 @@ class StreamableChatOpenAI(ChatOpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in a list of messages.
Args:
messages: The messages to count the tokens of.
Returns:
The number of tokens in the messages.
"""
tokens_per_message = 5
tokens_per_request = 3
message_tokens = tokens_per_request
message_strs = ''
for message in messages:
message_strs += message.content
message_tokens += tokens_per_message
# calc once
message_tokens += self.get_num_tokens(message_strs)
return message_tokens
@handle_llm_exceptions
@handle_openai_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
@@ -81,12 +62,18 @@ class StreamableChatOpenAI(ChatOpenAI):
) -> LLMResult:
return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(messages, stop, callbacks, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
model_kwargs = {
'top_p': params.get('top_p', 1),
'frequency_penalty': params.get('frequency_penalty', 0),
'presence_penalty': params.get('presence_penalty', 0),
}
del params['top_p']
del params['frequency_penalty']
del params['presence_penalty']
params['model_kwargs'] = model_kwargs
return params
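
A short sketch of the reshaping that get_kwargs_from_model_params performs for the chat models above: sampling parameters are moved under model_kwargs so they are forwarded to the underlying chat completion call instead of being treated as constructor fields (pop is used here as an equivalent of the get/del pair in the diff). The plain completion wrappers, StreamableOpenAI and StreamableAzureOpenAI, return params unchanged.

# Before: flat params from the app model config (values are illustrative).
params = {"temperature": 0.7, "top_p": 0.9, "frequency_penalty": 0.2, "presence_penalty": 0.0}

# After: sampling params nested under model_kwargs, everything else left in place.
params["model_kwargs"] = {
    "top_p": params.pop("top_p", 1),
    "frequency_penalty": params.pop("frequency_penalty", 0),
    "presence_penalty": params.pop("presence_penalty", 0),
}
print(params)
# {'temperature': 0.7, 'model_kwargs': {'top_p': 0.9, 'frequency_penalty': 0.2, 'presence_penalty': 0.0}}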

View File

@@ -2,14 +2,18 @@ import os
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Any, Mapping
from typing import Optional, List, Dict, Any, Mapping, Union, Tuple
from langchain import OpenAI
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableOpenAI(OpenAI):
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
max_retries: int = 1
"""Maximum number of retries to make when generating."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@@ -49,7 +53,7 @@ class StreamableOpenAI(OpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}}
@handle_llm_exceptions
@handle_openai_exceptions
def generate(
self,
prompts: List[str],
@@ -59,12 +63,6 @@ class StreamableOpenAI(OpenAI):
) -> LLMResult:
return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(prompts, stop, callbacks, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
return params

View File

@@ -1,6 +1,7 @@
import openai
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from models.provider import ProviderName
from core.llm.error_handle_wraps import handle_llm_exceptions
from core.llm.provider.base import BaseProvider
@@ -13,7 +14,7 @@ class Whisper:
self.client = openai.Audio
self.credentials = provider.get_credentials()
@handle_llm_exceptions
@handle_openai_exceptions
def transcribe(self, file):
return self.client.transcribe(
model='whisper-1',

View File

@@ -0,0 +1,27 @@
import logging
from functools import wraps
import anthropic
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
LLMBadRequestError
def handle_anthropic_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except anthropic.APIConnectionError as e:
logging.exception("Failed to connect to Anthropic API.")
raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
except anthropic.RateLimitError:
raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
except anthropic.AuthenticationError as e:
raise LLMAuthorizationError(f"Anthropic: {e.message}")
except anthropic.BadRequestError as e:
raise LLMBadRequestError(f"Anthropic: {e.message}")
except anthropic.APIStatusError as e:
raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
return wrapper
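
The wrapper above exists so callers only ever see the project's own error types rather than anthropic SDK exceptions. A self-contained sketch of the same translation pattern with stand-in exception classes (the real decorator maps anthropic SDK errors to the core.llm.error classes imported above):

from functools import wraps

class SDKRateLimited(Exception):        # stand-in for anthropic.RateLimitError
    pass

class RateLimitedForCaller(Exception):  # stand-in for LLMRateLimitError
    pass

def translate_errors(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except SDKRateLimited:
            raise RateLimitedForCaller("Anthropic: a 429 was received; back off a bit.")
    return wrapper

@translate_errors
def flaky_call():
    raise SDKRateLimited()

try:
    flaky_call()
except RateLimitedForCaller as e:
    print(e)  # Anthropic: a 429 was received; back off a bit.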

View File

@@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRat
LLMBadRequestError
def handle_llm_exceptions(func):
def handle_openai_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
@@ -29,27 +29,3 @@ def handle_llm_exceptions(func):
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper
def handle_llm_exceptions_async(func):
@wraps(func)
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except openai.error.InvalidRequestError as e:
logging.exception("Invalid request to OpenAI API.")
raise LLMBadRequestError(str(e))
except openai.error.APIConnectionError as e:
logging.exception("Failed to connect to OpenAI API.")
raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
logging.exception("OpenAI service unavailable.")
raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
except openai.error.RateLimitError as e:
raise LLMRateLimitError(str(e))
except openai.error.AuthenticationError as e:
raise LLMAuthorizationError(str(e))
except openai.error.OpenAIError as e:
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper

View File

@@ -1,7 +1,7 @@
from typing import Any, List, Dict, Union
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
@@ -12,8 +12,8 @@ from models.model import Conversation, Message
class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
conversation: Conversation
human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: Union[StreamableChatOpenAI | StreamableOpenAI]
ai_prefix: str = "Assistant"
llm: BaseLanguageModel
memory_key: str = "chat_history"
max_token_limit: int = 2000
message_limit: int = 10
@@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
return chat_messages
# prune the chat message if it exceeds the max token limit
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit and chat_messages:
pruned_memory.append(chat_messages.pop(0))
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
return chat_messages
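
A compact sketch of the pruning loop above: the oldest messages are dropped until the token estimate fits max_token_limit. The word-count token counter below is a crude stand-in for llm.get_num_tokens_from_messages.

def count_tokens(messages: list[str]) -> int:
    return sum(len(m.split()) for m in messages)  # crude stand-in for the real tokenizer

def prune(messages: list[str], max_token_limit: int) -> list[str]:
    while count_tokens(messages) > max_token_limit and messages:
        messages.pop(0)  # drop the oldest message first
    return messages

print(prune(["a b c d", "e f", "g h i"], max_token_limit=5))  # ['e f', 'g h i']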

View File

@@ -0,0 +1,301 @@
import math
from typing import Optional
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatOpenAI
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from pydantic import BaseModel, Field
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from libs import helper
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
class OrchestratorRuleParser:
"""Parse the orchestrator rule to entities."""
def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
self.tenant_id = tenant_id
self.app_model_config = app_model_config
self.agent_summary_model_name = "gpt-3.5-turbo-16k"
self.dataset_retrieve_model_name = "gpt-3.5-turbo"
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
-> Optional[AgentExecutor]:
if not self.app_model_config.agent_mode_dict:
return None
agent_mode_config = self.app_model_config.agent_mode_dict
model_dict = self.app_model_config.model_dict
chain = None
if agent_mode_config and agent_mode_config.get('enabled'):
tool_configs = agent_mode_config.get('tools', [])
agent_model_name = model_dict.get('name', 'gpt-4')
# add agent callback to record agent thoughts
agent_callback = AgentLoopGatherCallbackHandler(
model_name=agent_model_name,
conversation_message_task=conversation_message_task
)
chain_callback.agent_callback = agent_callback
agent_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name=agent_model_name,
temperature=0,
max_tokens=1500,
callbacks=[agent_callback, DifyStdOutCallbackHandler()]
)
planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))
# only OpenAI chat models (including Azure) support function calling; fall back to ReACT otherwise
if not isinstance(agent_llm, ChatOpenAI) \
and planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
planning_strategy = PlanningStrategy.REACT
summary_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name=self.agent_summary_model_name,
temperature=0,
max_tokens=500,
callbacks=[DifyStdOutCallbackHandler()]
)
tools = self.to_tools(
tool_configs=tool_configs,
conversation_message_task=conversation_message_task,
model_name=self.agent_summary_model_name,
rest_tokens=rest_tokens,
callbacks=[agent_callback, DifyStdOutCallbackHandler()]
)
if len(tools) == 0:
return None
dataset_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name=self.dataset_retrieve_model_name,
temperature=0,
max_tokens=500,
callbacks=[DifyStdOutCallbackHandler()]
)
agent_configuration = AgentConfiguration(
strategy=planning_strategy,
llm=agent_llm,
tools=tools,
summary_llm=summary_llm,
dataset_llm=dataset_llm,
memory=memory,
callbacks=[chain_callback, agent_callback],
max_iterations=10,
max_execution_time=400.0,
early_stopping_method="generate"
)
return AgentExecutor(agent_configuration)
return chain
def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
-> Optional[SensitiveWordAvoidanceChain]:
"""
Convert app sensitive word avoidance config to chain
:param kwargs:
:return:
"""
if not self.app_model_config.sensitive_word_avoidance_dict:
return None
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
sensitive_words = sensitive_word_avoidance_config.get("words", "")
if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
return SensitiveWordAvoidanceChain(
sensitive_words=sensitive_words.split(","),
canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
output_key="sensitive_word_avoidance_output",
callbacks=callbacks,
**kwargs
)
return None
def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
model_name: str, rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
"""
Convert app agent tool configs to tools
:param rest_tokens:
:param tool_configs: app agent tool configs
:param model_name:
:param conversation_message_task:
:param callbacks:
:return:
"""
tools = []
for tool_config in tool_configs:
tool_type = list(tool_config.keys())[0]
tool_val = list(tool_config.values())[0]
if not tool_val.get("enabled") or tool_val.get("enabled") is not True:
continue
tool = None
if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
elif tool_type == "web_reader":
tool = self.to_web_reader_tool(model_name)
elif tool_type == "google_search":
tool = self.to_google_search_tool()
elif tool_type == "wikipedia":
tool = self.to_wikipedia_tool()
elif tool_type == "current_datetime":
tool = self.to_current_datetime_tool()
if tool:
tool.callbacks.extend(callbacks)
tools.append(tool)
return tools
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
rest_tokens: int) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param rest_tokens:
:param tool_config:
:param conversation_message_task:
:return:
"""
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
).first()
if dataset and dataset.available_document_count == 0:
return None
k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
k=k,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)]
)
return tool
def to_web_reader_tool(self, model_name: str) -> Optional[BaseTool]:
"""
A tool for reading web pages
:return:
"""
summary_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name=model_name,
temperature=0,
max_tokens=500,
callbacks=[DifyStdOutCallbackHandler()]
)
tool = WebReaderTool(
llm=summary_llm,
max_chunk_length=4000,
continue_reading=True,
callbacks=[DifyStdOutCallbackHandler()]
)
return tool
def to_google_search_tool(self) -> Optional[BaseTool]:
tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
func_kwargs = tool_provider.credentials_to_func_kwargs()
if not func_kwargs:
return None
tool = Tool(
name="google_search",
description="A tool for performing a Google search and extracting snippets and webpages "
"when you need to search for something you don't know or when your information "
"is not up to date. "
"Input should be a search query.",
func=OptimizedSerpAPIWrapper(**func_kwargs).run,
args_schema=OptimizedSerpAPIInput,
callbacks=[DifyStdOutCallbackHandler()]
)
return tool
def to_current_datetime_tool(self) -> Optional[BaseTool]:
tool = Tool(
name="current_datetime",
description="A tool when you want to get the current date, time, week, month or year, "
"and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\".",
func=helper.get_current_datetime,
callbacks=[DifyStdOutCallbackHandler()]
)
return tool
def to_wikipedia_tool(self) -> Optional[BaseTool]:
class WikipediaInput(BaseModel):
query: str = Field(..., description="search query.")
return WikipediaQueryRun(
name="wikipedia",
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
args_schema=WikipediaInput,
callbacks=[DifyStdOutCallbackHandler()]
)
@classmethod
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
processing_rule = dataset.latest_process_rule
if not processing_rule:
return DEFAULT_K
if processing_rule.mode == "custom":
rules = processing_rule.rules_dict
if not rules:
return DEFAULT_K
segmentation = rules["segmentation"]
segment_max_tokens = segmentation["max_tokens"]
else:
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
# when rest_tokens is less than default context tokens
if rest_tokens < segment_max_tokens * DEFAULT_K:
return rest_tokens // segment_max_tokens
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
# when context_limit_tokens is less than default context tokens, use default_k
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
return DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space
return context_limit_tokens // segment_max_tokens
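
A worked example of _dynamic_calc_retrieve_k with assumed numbers (segment_max_tokens=500, DEFAULT_K=2, CONTEXT_TOKENS_PERCENT=0.3), showing all three branches:

import math

def calc_k(rest_tokens: int, segment_max_tokens: int = 500,
           default_k: int = 2, context_percent: float = 0.3) -> int:
    if rest_tokens < segment_max_tokens * default_k:            # < 1000: fit what we can
        return rest_tokens // segment_max_tokens
    context_limit_tokens = math.floor(rest_tokens * context_percent)
    if context_limit_tokens <= segment_max_tokens * default_k:  # <= 1000: keep the default
        return default_k
    return context_limit_tokens // segment_max_tokens           # expand into the 30% budget

print(calc_k(800))   # 1  (only 800 tokens left)
print(calc_k(3000))  # 2  (30% of 3000 = 900, not above 1000)
print(calc_k(6000))  # 3  (30% of 6000 = 1800, then 1800 // 500)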

View File

@@ -1,87 +0,0 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset
class DatasetTool(BaseTool):
"""Tool for querying a Dataset."""
dataset: Dataset
k: int = 2
def _run(self, tool_input: str) -> str:
if self.dataset.indexing_technique == "economy":
# use keyword table query
kw_table_index = KeywordTableIndex(
dataset=self.dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=5
)
)
documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
else:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=self.dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search(
tool_input,
search_type='similarity',
search_kwargs={
'k': self.k
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))
async def _arun(self, tool_input: str) -> str:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=self.dataset,
config=current_app.config,
embeddings=embeddings
)
documents = await vector_index.asearch(
tool_input,
search_type='similarity',
search_kwargs={
'k': 10
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))

View File

@@ -0,0 +1,105 @@
import re
from typing import Type
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool
from pydantic import Field, BaseModel
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset
class DatasetRetrieverToolInput(BaseModel):
dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
class DatasetRetrieverTool(BaseTool):
"""Tool for querying a Dataset."""
name: str = "dataset"
args_schema: Type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
tenant_id: str
dataset_id: str
k: int = 3
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
description = dataset.description.replace('\n', '').replace('\r', '')
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
description += '\nID of dataset MUST be ' + dataset.id
return cls(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description,
**kwargs
)
def _run(self, dataset_id: str, query: str) -> str:
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
match = re.search(pattern, dataset_id, re.IGNORECASE)
if match:
dataset_id = match.group()
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
return f'[{self.name} failed to find dataset with id {dataset_id}.]'
if dataset.indexing_technique == "economy":
# use keyword table query
kw_table_index = KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=5
)
)
documents = kw_table_index.search(query, search_kwargs={'k': self.k})
else:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
if self.k > 0:
documents = vector_index.search(
query,
search_type='similarity',
search_kwargs={
'k': self.k
}
)
else:
documents = []
hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()
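
Because the agent may wrap the dataset id in extra text, _run keeps only the first UUID-shaped substring. A small sketch of that extraction (the input string below is made up):

import re

pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
raw = "dataset id: 123e4567-e89b-12d3-a456-426614174000 (use this one)"
match = re.search(pattern, raw, re.IGNORECASE)
dataset_id = match.group() if match else raw
print(dataset_id)  # 123e4567-e89b-12d3-a456-426614174000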

View File

@@ -0,0 +1,63 @@
import base64
from abc import ABC, abstractmethod
from typing import Optional
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.tool import ToolProvider, ToolProviderName
class BaseToolProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
@abstractmethod
def get_provider_name(self) -> ToolProviderName:
raise NotImplementedError
@abstractmethod
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def credentials_to_func_kwargs(self) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def credentials_validate(self, credentials: dict):
raise NotImplementedError
def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
"""
Returns the Provider instance for the given tenant_id and tool_name.
"""
query = db.session.query(ToolProvider).filter(
ToolProvider.tenant_id == self.tenant_id,
ToolProvider.tool_name == self.get_provider_name().value
)
if must_enabled:
query = query.filter(ToolProvider.is_enabled == True)
return query.first()
def encrypt_token(self, token) -> str:
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
token = rsa.decrypt(base64.b64decode(token), self.tenant_id)
if obfuscated:
return self._obfuscated_token(token)
return token
def _obfuscated_token(self, token: str) -> str:
return token[:6] + '*' * (len(token) - 8) + token[-2:]
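
An example of the rule in _obfuscated_token: keep the first six and last two characters and mask the rest (the token value below is made up; very short tokens are not handled here):

def obfuscate(token: str) -> str:
    return token[:6] + '*' * (len(token) - 8) + token[-2:]

print(obfuscate("sk-abcdef1234567890"))  # first 6 and last 2 characters kept, middle masked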

View File

@@ -0,0 +1,2 @@
class ToolValidateFailedError(Exception):
description = "Tool Provider Validate failed"

View File

@@ -0,0 +1,77 @@
from typing import Optional
from core.tool.provider.base import BaseToolProvider
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
from models.tool import ToolProviderName
class SerpAPIToolProvider(BaseToolProvider):
def get_provider_name(self) -> ToolProviderName:
"""
Returns the name of the provider.
:return:
"""
return ToolProviderName.SERPAPI
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Returns the credentials for SerpAPI as a dictionary.
:param obfuscated: obfuscate credentials if True
:return:
"""
tool_provider = self.get_provider(must_enabled=True)
if not tool_provider:
return None
credentials = tool_provider.credentials
if not credentials:
return None
if credentials.get('api_key'):
credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)
return credentials
def credentials_to_func_kwargs(self) -> Optional[dict]:
"""
Returns the credentials function kwargs as a dictionary.
:return:
"""
credentials = self.get_credentials()
if not credentials:
return None
return {
'serpapi_api_key': credentials.get('api_key')
}
def credentials_validate(self, credentials: dict):
"""
Validates the given credentials.
:param credentials:
:return:
"""
if 'api_key' not in credentials or not credentials.get('api_key'):
raise ToolValidateFailedError("SerpAPI api_key is required.")
api_key = credentials.get('api_key')
try:
OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
except Exception as e:
raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
"""
Encrypts the given credentials.
:param credentials:
:return:
"""
credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
return credentials

View File

@@ -0,0 +1,43 @@
from typing import Optional
from core.tool.provider.base import BaseToolProvider
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
class ToolProviderService:
def __init__(self, tenant_id: str, provider_name: str):
self.provider = self._init_provider(tenant_id, provider_name)
def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
if provider_name == 'serpapi':
return SerpAPIToolProvider(tenant_id)
else:
raise Exception('tool provider {} not found'.format(provider_name))
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Returns the credentials for Tool as a dictionary.
:param obfuscated:
:return:
"""
return self.provider.get_credentials(obfuscated)
def credentials_validate(self, credentials: dict):
"""
Validates the given credentials.
:param credentials:
:raises: ValidateFailedError
"""
return self.provider.credentials_validate(credentials)
def encrypt_credentials(self, credentials: dict):
"""
Encrypts the given credentials.
:param credentials:
:return:
"""
return self.provider.encrypt_credentials(credentials)

View File

@@ -0,0 +1,51 @@
from langchain import SerpAPIWrapper
from pydantic import Field, BaseModel
class OptimizedSerpAPIInput(BaseModel):
query: str = Field(..., description="search query.")
class OptimizedSerpAPIWrapper(SerpAPIWrapper):
@staticmethod
def _process_response(res: dict, num_results: int = 5) -> str:
"""Process response from SerpAPI."""
if "error" in res.keys():
raise ValueError(f"Got error from SerpAPI: {res['error']}")
if "answer_box" in res.keys() and type(res["answer_box"]) == list:
res["answer_box"] = res["answer_box"][0]
if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
toret = res["answer_box"]["answer"]
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret = res["answer_box"]["snippet"]
elif (
"answer_box" in res.keys()
and "snippet_highlighted_words" in res["answer_box"].keys()
):
toret = res["answer_box"]["snippet_highlighted_words"][0]
elif (
"sports_results" in res.keys()
and "game_spotlight" in res["sports_results"].keys()
):
toret = res["sports_results"]["game_spotlight"]
elif (
"shopping_results" in res.keys()
and "title" in res["shopping_results"][0].keys()
):
toret = res["shopping_results"][:3]
elif (
"knowledge_graph" in res.keys()
and "description" in res["knowledge_graph"].keys()
):
toret = res["knowledge_graph"]["description"]
elif 'organic_results' in res.keys() and len(res['organic_results']) > 0:
toret = ""
for result in res["organic_results"][:num_results]:
if "link" in result:
toret += "----------------\nlink: " + result["link"] + "\n"
if "snippet" in result:
toret += "snippet: " + result["snippet"] + "\n"
else:
toret = "No good search result found"
return "search result:\n" + toret

View File

@@ -0,0 +1,419 @@
import hashlib
import json
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from typing import Type
import requests
from bs4 import BeautifulSoup, NavigableString, Comment, CData
from langchain.base_language import BaseLanguageModel
from langchain.chains.summarize import load_summarize_chain
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex
from core.data_loader import file_extractor
from core.data_loader.file_extractor import FileExtractor
FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:
{text}
"""
class WebReaderToolInput(BaseModel):
url: str = Field(..., description="URL of the website to read")
summary: bool = Field(
default=False,
description="When the user's question requires extracting the summarizing content of the webpage, "
"set it to true."
)
cursor: int = Field(
default=0,
description="Start reading from this character."
"Use when the first response was truncated"
"and you want to continue reading the page."
"The value cannot exceed 24000.",
)
class WebReaderTool(BaseTool):
"""Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
name: str = "web_reader"
args_schema: Type[BaseModel] = WebReaderToolInput
description: str = "use this to read a website. " \
"If you can answer the question based on the information provided, " \
"there is no need to use."
page_contents: str = None
url: str = None
max_chunk_length: int = 4000
summary_chunk_tokens: int = 4000
summary_chunk_overlap: int = 0
summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
continue_reading: bool = True
llm: BaseLanguageModel
def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
try:
if not self.page_contents or self.url != url:
page_contents = get_url(url)
self.page_contents = page_contents
self.url = url
else:
page_contents = self.page_contents
except Exception as e:
return f'Failed to read this website, caused by: {str(e)}.'
if summary:
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=self.summary_chunk_tokens,
chunk_overlap=self.summary_chunk_overlap,
separators=self.summary_separators
)
texts = character_splitter.split_text(page_contents)
docs = [Document(page_content=t) for t in texts]
# only use first 5 docs
if len(docs) > 5:
docs = docs[:5]
chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
try:
page_contents = chain.run(docs)
# todo use cache
except Exception as e:
return f'Failed to read this website, caused by: {str(e)}.'
else:
page_contents = page_result(page_contents, cursor, self.max_chunk_length)
if self.continue_reading and len(page_contents) >= self.max_chunk_length:
page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
return page_contents
async def _arun(self, url: str) -> str:
raise NotImplementedError
def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
return text[cursor: cursor + max_length]
def get_url(url: str) -> str:
"""Fetch URL and return the contents as a string."""
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
if head_response.status_code != 200:
return "URL returned status code {}.".format(head_response.status_code)
# check content-type
main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip()
if main_content_type not in supported_content_types:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES:
return FileExtractor.load_from_url(url, return_text=True)
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
a = extract_using_readabilipy(response.text)
if not a['plain_text'] or not a['plain_text'].strip():
return get_url_from_newspaper3k(url)
res = FULL_TEMPLATE.format(
title=a['title'],
authors=a['byline'],
publish_date=a['date'],
top_image="",
text=a['plain_text'] if a['plain_text'] else "",
)
return res
def get_url_from_newspaper3k(url: str) -> str:
a = Article(url)
a.download()
a.parse()
res = FULL_TEMPLATE.format(
title=a.title,
authors=a.authors,
publish_date=a.publish_date,
top_image=a.top_image,
text=a.text,
)
return res
def extract_using_readabilipy(html):
with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
f_html.write(html)
f_html.close()
html_path = f_html.name
# Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
article_json_path = html_path + ".json"
jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
with chdir(jsdir):
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
# Read output of call to Readability.parse() from JSON file and return as Python dictionary
with open(article_json_path, "r", encoding="utf-8") as json_file:
input_json = json.loads(json_file.read())
# Deleting files after processing
os.unlink(article_json_path)
os.unlink(html_path)
article_json = {
"title": None,
"byline": None,
"date": None,
"content": None,
"plain_content": None,
"plain_text": None
}
# Populate article fields from readability fields where present
if input_json:
if "title" in input_json and input_json["title"]:
article_json["title"] = input_json["title"]
if "byline" in input_json and input_json["byline"]:
article_json["byline"] = input_json["byline"]
if "date" in input_json and input_json["date"]:
article_json["date"] = input_json["date"]
if "content" in input_json and input_json["content"]:
article_json["content"] = input_json["content"]
article_json["plain_content"] = plain_content(article_json["content"], False, False)
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
if "textContent" in input_json and input_json["textContent"]:
article_json["plain_text"] = input_json["textContent"]
article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])
return article_json
def find_module_path(module_name):
for package_path in site.getsitepackages():
potential_path = os.path.join(package_path, module_name)
if os.path.exists(potential_path):
return potential_path
return None
@contextmanager
def chdir(path):
"""Change directory in context and return to original on exit"""
# From https://stackoverflow.com/a/37996581, couldn't find a built-in
original_path = os.getcwd()
os.chdir(path)
try:
yield
finally:
os.chdir(original_path)
def extract_text_blocks_as_plain_text(paragraph_html):
# Load article as DOM
soup = BeautifulSoup(paragraph_html, 'html.parser')
# Select all lists
list_elements = soup.find_all(['ul', 'ol'])
# Prefix text in all list items with "* " and make lists paragraphs
for list_element in list_elements:
plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
list_element.string = plain_items
list_element.name = "p"
# Select all text blocks
text_blocks = [s.parent for s in soup.find_all(string=True)]
text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
# Drop empty paragraphs
text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
return text_blocks
def plain_text_leaf_node(element):
# Extract all text, stripped of any child HTML elements and normalise it
plain_text = normalise_text(element.get_text())
if plain_text != "" and element.name == "li":
plain_text = "* {}, ".format(plain_text)
if plain_text == "":
plain_text = None
if "data-node-index" in element.attrs:
plain = {"node_index": element["data-node-index"], "text": plain_text}
else:
plain = {"text": plain_text}
return plain
def plain_content(readability_content, content_digests, node_indexes):
# Load article as DOM
soup = BeautifulSoup(readability_content, 'html.parser')
# Make all elements plain
elements = plain_elements(soup.contents, content_digests, node_indexes)
if node_indexes:
# Add node index attributes to nodes
elements = [add_node_indexes(element) for element in elements]
# Replace article contents with plain elements
soup.contents = elements
return str(soup)
def plain_elements(elements, content_digests, node_indexes):
# Get plain content versions of all elements
elements = [plain_element(element, content_digests, node_indexes)
for element in elements]
if content_digests:
# Add content digest attribute to nodes
elements = [add_content_digest(element) for element in elements]
return elements
def plain_element(element, content_digests, node_indexes):
# For lists, we make each item plain text
if is_leaf(element):
# For leaf node elements, extract the text content, discarding any HTML tags
# 1. Get element contents as text
plain_text = element.get_text()
# 2. Normalise the extracted text string to a canonical representation
plain_text = normalise_text(plain_text)
# 3. Update element content to be plain text
element.string = plain_text
elif is_text(element):
if is_non_printing(element):
# The simplified HTML may have come from Readability.js so might
# have non-printing text (e.g. Comment or CData). In this case, we
# keep the structure, but ensure that the string is empty.
element = type(element)("")
else:
plain_text = element.string
plain_text = normalise_text(plain_text)
element = type(element)(plain_text)
else:
# If not a leaf node or leaf type call recursively on child nodes, replacing
element.contents = plain_elements(element.contents, content_digests, node_indexes)
return element
def add_node_indexes(element, node_index="0"):
# Can't add attributes to string types
if is_text(element):
return element
# Add index to current element
element["data-node-index"] = node_index
# Add index to child elements
for local_idx, child in enumerate(
[c for c in element.contents if not is_text(c)], start=1):
# Can't add attributes to leaf string types
child_index = "{stem}.{local}".format(
stem=node_index, local=local_idx)
add_node_indexes(child, node_index=child_index)
return element
def normalise_text(text):
"""Normalise unicode and whitespace."""
# Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
text = strip_control_characters(text)
text = normalise_unicode(text)
text = normalise_whitespace(text)
return text
def strip_control_characters(text):
"""Strip out unicode control characters which might break the parsing."""
# Unicode control characters
# [Cc]: Other, Control [includes new lines]
# [Cf]: Other, Format
# [Cn]: Other, Not Assigned
# [Co]: Other, Private Use
# [Cs]: Other, Surrogate
control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
retained_chars = ['\t', '\n', '\r', '\f']
# Remove non-printing control characters
return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])
def normalise_unicode(text):
"""Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
normal_form = "NFKC"
text = unicodedata.normalize(normal_form, text)
return text
def normalise_whitespace(text):
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
text = regex.sub(r"\s+", " ", text)
# Remove leading and trailing whitespace
text = text.strip()
return text
def is_leaf(element):
return (element.name in ['p', 'li'])
def is_text(element):
return isinstance(element, NavigableString)
def is_non_printing(element):
return any(isinstance(element, _e) for _e in [Comment, CData])
def add_content_digest(element):
if not is_text(element):
element["data-content-digest"] = content_digest(element)
return element
def content_digest(element):
if is_text(element):
# Hash
trimmed_string = element.string.strip()
if trimmed_string == "":
digest = ""
else:
digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
else:
contents = element.contents
num_contents = len(contents)
if num_contents == 0:
# No hash when no child elements exist
digest = ""
elif num_contents == 1:
# If single child, use digest of child
digest = content_digest(contents[0])
else:
# Build content digest from the "non-empty" digests of child nodes
digest = hashlib.sha256()
child_digests = list(
filter(lambda x: x != "", [content_digest(content) for content in contents]))
for child in child_digests:
digest.update(child.encode('utf-8'))
digest = digest.hexdigest()
return digest
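
To make the cursor mechanics of WebReaderTool concrete, a self-contained sketch of the paging done by page_result together with the continue-reading hint appended in _run when a page is truncated (the hint wording below is simplified):

def read_page(text: str, cursor: int, max_length: int) -> str:
    chunk = text[cursor: cursor + max_length]
    if len(chunk) >= max_length:
        chunk += f"\n[TRUNCATED: continue with cursor={cursor + len(chunk)}]"
    return chunk

page = "x" * 10_000
first = read_page(page, cursor=0, max_length=4000)
print(first[-40:])  # tail shows the continue-with-cursor hint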

View File

@@ -1,4 +1,7 @@
from flask import current_app
from events.tenant_event import tenant_was_updated
from models.provider import ProviderName
from services.provider_service import ProviderService
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs):
tenant = sender
if tenant.status == 'normal':
ProviderService.create_system_provider(tenant)
ProviderService.create_system_provider(
tenant,
ProviderName.OPENAI.value,
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
True
)
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)

View File

@@ -1,4 +1,7 @@
from flask import current_app
from events.tenant_event import tenant_was_created
from models.provider import ProviderName
from services.provider_service import ProviderService
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs):
tenant = sender
if tenant.status == 'normal':
ProviderService.create_system_provider(tenant)
ProviderService.create_system_provider(
tenant,
ProviderName.OPENAI.value,
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
True
)
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)

View File

@@ -153,3 +153,9 @@ def get_remote_ip(request):
def generate_text_hash(text: str) -> str:
hash_text = str(text) + 'None'
return sha256(hash_text.encode()).hexdigest()
def get_current_datetime(type: str) -> str:
# get current time
current_time = datetime.utcnow()
return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")
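
For reference, the strftime pattern above renders like this (the timestamp is illustrative):

from datetime import datetime
print(datetime(2023, 7, 27, 14, 3, 0).strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A"))
# 2023-07-27 14:03:00 UTC+0000 Thursday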

View File

@@ -0,0 +1,32 @@
"""add is_universal in apps
Revision ID: 2beac44e5f5f
Revises: a5b56fb053ef
Create Date: 2023-07-07 12:11:29.156057
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '2beac44e5f5f'
down_revision = 'a5b56fb053ef'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.add_column(sa.Column('is_universal', sa.Boolean(), server_default=sa.text('false'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.drop_column('is_universal')
# ### end Alembic commands ###

View File

@@ -0,0 +1,44 @@
"""add tool providers
Revision ID: 7ce5a52e4eee
Revises: 2beac44e5f5f
Create Date: 2023-07-10 10:26:50.074515
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '7ce5a52e4eee'
down_revision = '2beac44e5f5f'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tool_providers',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', postgresql.UUID(), nullable=False),
sa.Column('tool_name', sa.String(length=40), nullable=False),
sa.Column('encrypted_credentials', sa.Text(), nullable=True),
sa.Column('is_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('sensitive_word_avoidance')
op.drop_table('tool_providers')
# ### end Alembic commands ###

Some files were not shown because too many files have changed in this diff.