Compare commits

...

67 Commits

Author SHA1 Message Date
takatost
5d48406d64 feat: bump version to 0.3.29 (#1462) 2023-11-06 06:55:17 -06:00
takatost
2b2dbabc11 fix: prompt variables validate when using external data tools (#1465) 2023-11-06 06:31:41 -06:00
zxhlyh
13b64bc55a fix: refresh api-based-extension (#1464) 2023-11-06 20:29:41 +08:00
zxhlyh
279f099ba0 fix: chat style (#1463) 2023-11-06 20:11:55 +08:00
zxhlyh
32747641e4 feat: add api-based extension & external data tool & moderation (#1459) 2023-11-06 19:36:32 +08:00
Garfield Dai
db43ed6f41 feat: add api-based extension & external data tool & moderation backend (#1403)
Co-authored-by: takatost <takatost@gmail.com>
2023-11-06 19:36:16 +08:00
YiLi
7699621983 fix: Use correct typehint for return values (#1454)
Co-authored-by: lethe <lethe>
2023-11-06 04:50:51 -06:00
takatost
4dfbcd0b4e feat: support chatglm_turbo model #1443 (#1460) 2023-11-06 04:33:05 -06:00
crazywoola
a9ee18300e fix: service suggested api (#1452) 2023-11-04 19:59:14 +08:00
dependabot[bot]
b4861d2b5c chore(deps): bump word-wrap from 1.2.3 to 1.2.5 in /web (#1440)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-01 11:26:25 +08:00
dependabot[bot]
913f2b84a6 chore(deps-dev): bump postcss from 8.4.24 to 8.4.31 in /web (#1439)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-01 11:24:43 +08:00
dependabot[bot]
cc89933d8f chore(deps): bump crypto-js from 4.1.1 to 4.2.0 in /web (#1437)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-01 11:24:33 +08:00
dependabot[bot]
a14ea6582d chore(deps): bump semver from 5.7.1 to 5.7.2 in /web (#1436)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-01 11:24:24 +08:00
takatost
076f3289d2 feat: add spark v3.0 llm support (#1434) 2023-10-31 03:13:11 -05:00
crazywoola
518083dfe0 fix: metadata not saved (#1429) 2023-10-30 14:39:15 +08:00
crazywoola
2b366bb321 fix: delete app and setting modal is not working in firefox (#1427) 2023-10-29 14:22:05 +08:00
Hickays
292d4c077a fix: Add icons for apps in "Related apps list" (#1425) 2023-10-27 17:55:38 +08:00
zxhlyh
fc4c03640d fix: provider delete api key modal z-index (#1416) 2023-10-26 10:35:03 +08:00
Charlie.Wei
985253197f mermaid front-end rendering initialization exception handling logic o… (#1407) 2023-10-26 10:19:04 +08:00
Hickays
48b4249790 fix: workspace app avatar is abnormal (#1411) 2023-10-26 10:18:38 +08:00
takatost
fb64fcb271 feat: upgrade xinference-client to 0.5.4 (#1402) 2023-10-23 05:49:32 -05:00
takatost
41e452dcc5 fix: hex problem (#1395) 2023-10-22 04:15:54 -05:00
yangbo.zhou
d218c66e25 Added diagram picture file for docker-compose yaml file visualization. (#1374) 2023-10-22 09:55:31 +08:00
Panmuse
e173b1cb2a Update README_CN.md (#1390) 2023-10-21 20:41:26 -05:00
Panmuse
9b598db559 Update README.md (#1389) 2023-10-21 20:41:15 -05:00
takatost
e122d677ad fix: return wrong when init 0 quota in trial provider (#1394) 2023-10-21 14:02:38 -05:00
takatost
4c63cbf5b1 feat: adjust anthropic (#1387) 2023-10-20 02:27:46 -05:00
Charlie.Wei
288705fefd Chrome Dify Chatbot Plug-in (#1378)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
2023-10-19 07:54:43 -05:00
Joel
8c4ae98f3d feat: add advanced prompt doc link (#1363) 2023-10-19 17:52:30 +08:00
Joel
08aa367892 feat: add context missing warning (#1384)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-10-19 17:52:14 +08:00
Joel
ff527a0190 fix: not load dataset config (#1381) 2023-10-19 13:55:25 +08:00
zxhlyh
6e05f8ca93 fix: npm run start (#1380) 2023-10-19 11:38:03 +08:00
Joel
6309d070d1 feat: enhance prompt mode copywriting (#1379) 2023-10-19 11:19:34 +08:00
Garfield Dai
fe14130b3c refactor advanced prompt core. (#1350) 2023-10-18 20:02:52 +08:00
wayne.wang
52ebffa857 fix: app config zhipu chatglm_std model, but it still use chatglm_lit… (#1377)
Co-authored-by: wayne.wang <wayne.wang@beibei.com>
2023-10-18 05:07:36 -05:00
zxhlyh
d14f15863d fix: i18n runtime error (#1376) 2023-10-18 16:00:56 +08:00
takatost
7c9b585a47 feat: support weixin ernie-bot-4 and chat mode (#1375) 2023-10-18 02:35:24 -05:00
takatost
c039f4af83 fix: app model config detached in completion thread (#1366) 2023-10-17 08:18:08 -05:00
takatost
07285e5f8b feat: optimize completion model agent (#1364) 2023-10-17 06:54:59 -05:00
Chenglong.li
16d80ebab3 Fix milvus configuration error (#1362)
Signed-off-by: JackLCL <chenglong.li@zilliz.com>
2023-10-17 17:40:40 +08:00
zxhlyh
61e816f24c feat: logo (#1356) 2023-10-16 15:26:25 +08:00
takatost
2feb16d957 feat: bump version to 0.3.28 (#1349) 2023-10-14 11:49:56 -05:00
crazywoola
3043fbe73b remove the suggested api for completion app (#1347) 2023-10-14 10:05:33 -05:00
Hickays
9f99c3f55b fix: modal z-index (#1343) 2023-10-13 05:55:03 -05:00
Joel
a07a6d8c26 feat: switch to generation model set default stop word (#1341) 2023-10-13 16:47:22 +08:00
Garfield Dai
695841a3cf Feat/advanced prompt enhancement (#1340) 2023-10-13 16:47:01 +08:00
takatost
3efaa713da feat: use xinference client instead of xinference (#1339) 2023-10-13 02:46:09 -05:00
takatost
9822f687f7 fix: max tokens of OpenAI gpt-3.5-turbo-instruct to 4097 (#1338) 2023-10-13 02:07:07 -05:00
crazywoola
b9d83c04bc fix: modal z-index (#1337) 2023-10-13 14:58:53 +08:00
Charlie.Wei
298ad6782d Add Message Suggested Api (#1326)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-10-13 14:07:32 +08:00
takatost
f4be2b8bcd fix: raise error in minimax stream generate (#1336) 2023-10-12 23:48:28 -05:00
crazywoola
e83e239faf fix: value.join is not a function in log list (#1332) 2023-10-13 11:34:24 +08:00
taokuizu
62bf7f0fc2 fix: new app with template display (#1322) 2023-10-13 10:18:33 +08:00
takatost
7dea485d57 feat: bump version to 0.3.27 (#1331) 2023-10-12 10:37:48 -05:00
zxhlyh
5b9858a8a3 feat: advanced prompt (#1330)
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: Gillian97 <jinling.sunshine@gmail.com>
2023-10-12 23:14:28 +08:00
Garfield Dai
42a5b3ec17 feat: advanced prompt backend (#1301)
Co-authored-by: takatost <takatost@gmail.com>
2023-10-12 10:13:10 -05:00
takatost
2d1cb076c6 fix: dataset segment not exist return agent response (#1329) 2023-10-12 04:40:20 -05:00
Jyong
289c93d081 Feat/improve document delete logic (#1325)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-12 13:30:44 +08:00
takatost
c0fe706597 feat: adjust to only build the latest image when pushing a tag. (#1324) 2023-10-11 23:38:07 -05:00
takatost
9cba1c8bf4 fix: retriever_resource missing (#1317) 2023-10-11 14:37:11 -05:00
takatost
cbf095465c feat: remove llm client use (#1316) 2023-10-11 14:02:53 -05:00
KVOJJJin
c007dbdc13 Feat: add document of authorization (#1311) 2023-10-11 08:03:36 -05:00
takatost
ff493d017b fix: minimax tests (#1313) 2023-10-11 07:49:26 -05:00
Jyong
7f6ad9653e Fix/grpc gevent compatible (#1314)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-11 20:48:35 +08:00
takatost
2851a9f04e feat: optimize minimax llm call (#1312) 2023-10-11 07:17:41 -05:00
takatost
c536f85b2e fix: compatibility issues with the tongyi model. (#1310) 2023-10-11 05:16:26 -05:00
takatost
b1352ff8b7 feat: using random sampling to check if it violates the review mechan… (#1308) 2023-10-11 04:11:20 -05:00
411 changed files with 14414 additions and 2712 deletions

View File

@@ -31,7 +31,7 @@ jobs:
with:
images: langgenius/dify-api
tags: |
type=raw,value=latest,enable={{is_default_branch}}
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=semver,pattern={{major}}.{{minor}}.{{patch}}

View File

@@ -31,7 +31,7 @@ jobs:
with:
images: langgenius/dify-web
tags: |
type=raw,value=latest,enable={{is_default_branch}}
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=semver,pattern={{major}}.{{minor}}.{{patch}}

View File

@@ -1,37 +0,0 @@
import os
import re
from zhon.hanzi import punctuation
def has_chinese_characters(text):
for char in text:
if '\u4e00' <= char <= '\u9fff' or char in punctuation:
return True
return False
def check_file_for_chinese_comments(file_path):
with open(file_path, 'r', encoding='utf-8') as file:
for line_number, line in enumerate(file, start=1):
if has_chinese_characters(line):
print(f"Found Chinese characters in {file_path} on line {line_number}:")
print(line.strip())
return True
return False
def main():
has_chinese = False
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py',
'prompts.py']
for root, _, files in os.walk("."):
for file in files:
if file.endswith(".py") and file not in excluded_files:
file_path = os.path.join(root, file)
if check_file_for_chinese_comments(file_path):
has_chinese = True
if has_chinese:
raise Exception("Found Chinese characters in Python files. Please remove them.")
if __name__ == "__main__":
main()

View File

@@ -1,31 +0,0 @@
name: Check for Chinese comments
on:
push:
branches:
- 'main'
pull_request:
branches:
- main
jobs:
check-chinese-comments:
runs-on: ubuntu-latest
steps:
- name: Check out repository
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install zhon
- name: Run script to check for Chinese comments
run: |
python .github/workflows/check_no_chinese_comments.py

View File

@@ -37,7 +37,6 @@ https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e19
We provide the following free resources for registered Dify cloud users (sign up at [dify.ai](https://dify.ai)):
* 600,000 free Claude model tokens to build Claude-powered apps
* 200 free OpenAI queries to build OpenAI-based apps

View File

@@ -36,7 +36,6 @@ https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e19
We provide the following resources free of charge to all registered cloud users (log in at [dify.ai](https://cloud.dify.ai) to use them):
* 600,000 Claude model message tokens for building AI apps based on the Claude model
* 200 OpenAI model message calls for building AI apps based on OpenAI models
* 3,000,000 iFLYTEK Spark model tokens for building AI apps based on the Spark model
* 1,000,000 MiniMax tokens for building AI apps based on MiniMax models

View File

@@ -10,7 +10,7 @@
"request": "launch",
"module": "flask",
"env": {
"FLASK_APP": "api/app.py",
"FLASK_APP": "app.py",
"FLASK_DEBUG": "1",
"GEVENT_SUPPORT": "True"
},

View File

@@ -6,6 +6,9 @@ from werkzeug.exceptions import Unauthorized
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
monkey.patch_all()
if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()
import logging
import json
@@ -16,7 +19,7 @@ from flask_cors import CORS
from core.model_providers.providers import hosted
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_stripe
ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension
from extensions.ext_database import db
from extensions.ext_login import login_manager
@@ -76,6 +79,7 @@ def create_app(test_config=None) -> Flask:
def initialize_extensions(app):
# Since the application instance is now created, pass it to each Flask
# extension instance to bind it to the Flask application instance (app)
ext_code_based_extension.init()
ext_database.init_app(app)
ext_migrate.init(app, db)
ext_redis.init_app(app)
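
For context, the three lines added to this startup block address a known hang: the Milvus client talks to the server over gRPC, and gRPC's native threads do not cooperate with gevent's monkey-patched event loop unless gRPC is told about it. A minimal standalone sketch of the same pattern (assuming grpcio and gevent are installed; not the exact Dify code):

import os

if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
    from gevent import monkey
    monkey.patch_all()  # make blocking I/O cooperative for the whole process

    if os.environ.get("VECTOR_STORE") == 'milvus':
        # Milvus is reached over gRPC; init_gevent() rewires gRPC so its
        # calls do not block the gevent hub.
        import grpc.experimental.gevent
        grpc.experimental.gevent.init_gevent()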

View File

@@ -57,6 +57,7 @@ DEFAULTS = {
'CLEAN_DAY_SETTING': 30,
'UPLOAD_FILE_SIZE_LIMIT': 15,
'UPLOAD_FILE_BATCH_LIMIT': 5,
'OUTPUT_MODERATION_BUFFER_SIZE': 300
}
@@ -92,7 +93,7 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.26"
self.CURRENT_VERSION = "0.3.29"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -228,6 +229,9 @@ class Config:
self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
# moderation settings
self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
class CloudEditionConfig(Config):
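
The new OUTPUT_MODERATION_BUFFER_SIZE setting follows the existing DEFAULTS/get_env pattern in this file. get_env's implementation is not part of this diff, so the sketch below only illustrates the assumed fallback behavior (environment variable first, DEFAULTS entry otherwise):

import os

DEFAULTS = {'OUTPUT_MODERATION_BUFFER_SIZE': 300}

def get_env(key):
    # assumption: the real get_env falls back to the DEFAULTS table when the env var is unset
    return os.environ.get(key, DEFAULTS.get(key))

# mirrors self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
buffer_size = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))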

View File

@@ -31,6 +31,7 @@ model_templates = {
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@@ -81,6 +82,7 @@ model_templates = {
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@@ -137,10 +139,11 @@ demo_model_templates = {
},
opening_statement='',
suggested_questions=None,
pre_prompt="Please translate the following text into {{target_language}}:\n",
pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
@@ -169,6 +172,13 @@ demo_model_templates = {
'Italian',
]
}
},{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": ""
}
}
])
)
@@ -200,6 +210,7 @@ demo_model_templates = {
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,
@@ -255,10 +266,11 @@ demo_model_templates = {
},
opening_statement='',
suggested_questions=None,
pre_prompt="请将以下文本翻译为{{target_language}}:\n",
pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
@@ -287,6 +299,13 @@ demo_model_templates = {
"意大利语",
]
}
},{
"paragraph": {
"label": "文本内容",
"variable": "query",
"required": True,
"default": ""
}
}
])
)
@@ -318,6 +337,7 @@ demo_model_templates = {
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,

View File

@@ -6,10 +6,10 @@ bp = Blueprint('console', __name__, url_prefix='/console/api')
api = ExternalApi(bp)
# Import other controllers
from . import setup, version, apikey, admin
from . import extension, setup, version, apikey, admin
# Import app controllers
from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate

View File

@@ -0,0 +1,25 @@
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class AdvancedPromptTemplateList(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('app_mode', type=str, required=True, location='args')
parser.add_argument('model_mode', type=str, required=True, location='args')
parser.add_argument('has_context', type=str, required=False, default='true', location='args')
parser.add_argument('model_name', type=str, required=True, location='args')
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args)
api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
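
A hypothetical usage sketch for the console route registered above; the host, port, and authorization header are assumptions (the console API requires a logged-in session, which this diff does not cover), while the four query arguments come straight from the parser:

import requests

resp = requests.get(
    'http://localhost:5001/console/api/app/prompt-templates',  # assumed host/port
    params={
        'app_mode': 'chat',            # or 'completion'
        'model_mode': 'chat',          # or 'completion'
        'has_context': 'true',         # optional, defaults to 'true'
        'model_name': 'gpt-3.5-turbo',
    },
    headers={'Authorization': 'Bearer <console-access-token>'},  # assumed auth scheme
)
print(resp.json())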

View File

@@ -12,35 +12,6 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
class IntroductionGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('prompt_template', type=str, required=True, location='json')
args = parser.parse_args()
account = current_user
try:
answer = LLMGenerator.generate_introduction(
account.current_tenant_id,
args['prompt_template']
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
return {'introduction': answer}
class RuleGenerateApi(Resource):
@setup_required
@login_required
@@ -72,5 +43,4 @@ class RuleGenerateApi(Resource):
return rules
api.add_resource(IntroductionGenerateApi, '/introduction-generate')
api.add_resource(RuleGenerateApi, '/rule-generate')

View File

@@ -295,8 +295,8 @@ class MessageSuggestedQuestionApi(Resource):
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=current_user,
message_id=message_id,
user=current_user,
check_enabled=False
)
except MessageNotExistsError:
@@ -329,7 +329,7 @@ class MessageApi(Resource):
message_id = str(message_id)
# get app info
app_model = _get_app(app_id, 'chat')
app_model = _get_app(app_id)
message = db.session.query(Message).filter(
Message.id == message_id,

View File

@@ -27,6 +27,7 @@ class AppParameterApi(InstalledAppResource):
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw
}
@marshal_with(parameters_fields)
@@ -42,7 +43,8 @@ class AppParameterApi(InstalledAppResource):
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
}

View File

@@ -0,0 +1,114 @@
from flask_restful import Resource, reqparse, marshal_with
from flask_login import current_user
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from models.api_based_extension import APIBasedExtension
from fields.api_based_extension_fields import api_based_extension_fields
from services.code_based_extension_service import CodeBasedExtensionService
from services.api_based_extension_service import APIBasedExtensionService
class CodeBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('module', type=str, required=True, location='args')
args = parser.parse_args()
return {
'module': args['module'],
'data': CodeBasedExtensionService.get_code_based_extension(args['module'])
}
class APIBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self):
tenant_id = current_user.current_tenant_id
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('api_endpoint', type=str, required=True, location='json')
parser.add_argument('api_key', type=str, required=True, location='json')
args = parser.parse_args()
extension_data = APIBasedExtension(
tenant_id=current_user.current_tenant_id,
name=args['name'],
api_endpoint=args['api_endpoint'],
api_key=args['api_key']
)
return APIBasedExtensionService.save(extension_data)
class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self, id):
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self, id):
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('api_endpoint', type=str, required=True, location='json')
parser.add_argument('api_key', type=str, required=True, location='json')
args = parser.parse_args()
extension_data_from_db.name = args['name']
extension_data_from_db.api_endpoint = args['api_endpoint']
if args['api_key'] != '[__HIDDEN__]':
extension_data_from_db.api_key = args['api_key']
return APIBasedExtensionService.save(extension_data_from_db)
@setup_required
@login_required
@account_initialization_required
def delete(self, id):
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
APIBasedExtensionService.delete(extension_data_from_db)
return {'result': 'success'}
api.add_resource(CodeBasedExtensionAPI, '/code-based-extension')
api.add_resource(APIBasedExtensionAPI, '/api-based-extension')
api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/<uuid:id>')
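
A hedged end-to-end sketch of the three console routes registered above. The base URL, auth header, and the shape of the response (including the id field) are assumptions; the '[__HIDDEN__]' sentinel for keeping a stored api_key is taken from the handler itself:

import requests

BASE = 'http://localhost:5001/console/api'                    # assumed host/port
HEADERS = {'Authorization': 'Bearer <console-access-token>'}  # assumed auth scheme

# create an API-based extension for the current tenant
created = requests.post(f'{BASE}/api-based-extension', headers=HEADERS, json={
    'name': 'my-moderation-hook',
    'api_endpoint': 'https://example.com/dify-extension',
    'api_key': 'secret-key',
}).json()

# list all extensions
print(requests.get(f'{BASE}/api-based-extension', headers=HEADERS).json())

# update while keeping the stored api_key by sending the hidden sentinel
requests.post(f"{BASE}/api-based-extension/{created['id']}", headers=HEADERS, json={
    'name': 'my-moderation-hook',
    'api_endpoint': 'https://example.com/dify-extension',
    'api_key': '[__HIDDEN__]',
})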

View File

@@ -28,6 +28,7 @@ class AppParameterApi(AppApiResource):
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw
}
@marshal_with(parameters_fields)
@@ -42,7 +43,8 @@ class AppParameterApi(AppApiResource):
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
}

View File

@@ -183,4 +183,3 @@ api.add_resource(CompletionApi, '/completion-messages')
api.add_resource(CompletionStopApi, '/completion-messages/<string:task_id>/stop')
api.add_resource(ChatApi, '/chat-messages')
api.add_resource(ChatStopApi, '/chat-messages/<string:task_id>/stop')

View File

@@ -54,6 +54,7 @@ class ConversationDetailApi(AppApiResource):
raise NotFound("Conversation Not Exists.")
return {"result": "success"}, 204
class ConversationRenameApi(AppApiResource):
@marshal_with(simple_conversation_fields)

View File

@@ -10,6 +10,8 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from libs.helper import TimestampField, uuid_value
from services.message_service import MessageService
from extensions.ext_database import db
from models.model import Message, EndUser
class MessageListApi(AppApiResource):
@@ -96,5 +98,38 @@ class MessageFeedbackApi(AppApiResource):
return {'result': 'success'}
class MessageSuggestedApi(AppApiResource):
def get(self, app_model, end_user, message_id):
message_id = str(message_id)
if app_model.mode != 'chat':
raise NotChatAppError()
try:
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
).first()
if end_user is None and message.from_end_user_id is not None:
user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.id == message.from_end_user_id,
EndUser.type == 'service_api'
).first()
else:
user = end_user
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=user,
message_id=message_id,
check_enabled=False
)
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {'result': 'success', 'data': questions}
api.add_resource(MessageListApi, '/messages')
api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
api.add_resource(MessageSuggestedApi, '/messages/<uuid:message_id>/suggested')
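
A minimal usage sketch for the new service-API route; the /v1 prefix, the Bearer app API key, and the user query parameter are assumptions based on how other service-API endpoints are typically called, not something this diff shows:

import requests

resp = requests.get(
    'http://localhost:5001/v1/messages/<message-uuid>/suggested',  # path registered above
    headers={'Authorization': 'Bearer <app-api-key>'},             # assumed auth scheme
    params={'user': 'end-user-123'},                               # assumed end-user identifier
)
print(resp.json())  # expected shape: {'result': 'success', 'data': [...]}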

View File

@@ -27,6 +27,7 @@ class AppParameterApi(WebApiResource):
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw
}
@marshal_with(parameters_fields)
@@ -41,7 +42,8 @@ class AppParameterApi(WebApiResource):
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
}

View File

@@ -139,7 +139,7 @@ class ChatStopApi(WebApiResource):
return {'result': 'success'}, 200
def compact_response(response: Union[dict | Generator]) -> Response:
def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
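
The one-character fix above deserves a note: Union[dict | Generator] evaluates dict | Generator first, which raises TypeError on Python versions before 3.10 (PEP 604 union syntax), whereas Union[dict, Generator] expresses the same two-member union on any supported interpreter. A self-contained sketch of the corrected annotation (imports here are illustrative):

from typing import Generator, Union

def compact_response(response: Union[dict, Generator]) -> str:
    # dict -> a single blocking JSON payload, Generator -> a streaming response
    if isinstance(response, dict):
        return 'blocking'
    return 'streaming'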

View File

@@ -115,7 +115,7 @@ class MessageMoreLikeThisApi(WebApiResource):
streaming = args['response_mode'] == 'streaming'
try:
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming)
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@@ -0,0 +1 @@
import core.moderation.base

View File

@@ -2,14 +2,18 @@ import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -24,6 +28,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
def should_use_agent(self, query: str):
"""
return should use agent
@@ -65,17 +73,57 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
return AgentFinish(return_values={"output": observation}, log=observation)
try:
agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
else:
agent_decision.return_values['output'] = ''
return agent_decision
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
def real_plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
)
ai_message = AIMessage(
content=result.content,
additional_kwargs={
'function_call': result.function_call
}
)
agent_decision = _parse_ai_message(ai_message)
return agent_decision
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
@@ -87,7 +135,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -96,11 +144,15 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
),
**kwargs: Any,
) -> BaseSingleActionAgent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_instance=model_instance,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)

View File

@@ -5,21 +5,40 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
_format_intermediate_steps
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \
get_buffer_string
from langchain.tools import BaseTool
from pydantic import root_validator
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_instance: BaseLLM = None
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -28,12 +47,16 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
),
**kwargs: Any,
) -> BaseSingleActionAgent:
return super().from_llm_and_tools(
llm=llm,
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_instance=model_instance,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=cls.get_system_message(),
**kwargs,
)
@@ -44,23 +67,26 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
:param query:
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 40
original_max_tokens = self.model_instance.model_kwargs.max_tokens
self.model_instance.model_kwargs.max_tokens = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {})
function_call = result.function_call
self.llm.max_tokens = original_max_tokens
self.model_instance.model_kwargs.max_tokens = original_max_tokens
return True if function_call else False
@@ -93,10 +119,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
)
agent_decision = _parse_ai_message(predicted_message)
ai_message = AIMessage(
content=result.content,
additional_kwargs={
'function_call': result.function_call
}
)
agent_decision = _parse_ai_message(ai_message)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
@@ -122,3 +157,142 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model_instance.model_provider.provider_name == 'azure_openai':
model = model_instance.base_model_name
model = model.replace("gpt-35", "gpt-3.5")
else:
model = model_instance.base_model_name
tiktoken_ = _import_tiktoken()
try:
encoding = tiktoken_.encoding_for_model(model)
except KeyError:
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens

View File

@@ -1,140 +0,0 @@
from typing import cast, List
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from pydantic import BaseModel
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.model_providers.models.llm.base import BaseLLM
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel = None
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
summary_handler = SummarizerMixin(llm=self.summary_llm)
self.moving_summary_buffer = summary_handler.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
llm = cast(ChatOpenAI, model_instance.client)
model, encoding = llm._get_encoding_model()
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens

View File

@@ -1,107 +0,0 @@
from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain.agents import BaseMultiActionAgent
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
_parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseMultiActionAgent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=cls.get_system_message(),
**kwargs,
)
def should_use_agent(self, query: str):
"""
return should use agent
:param query:
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 15
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {})
self.llm.max_tokens = original_max_tokens
return True if function_call else False
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
@classmethod
def get_system_message(cls):
# get current time
return SystemMessage(content="You are a helpful AI assistant.\n"
"The current date or current time you know is wrong.\n"
"Respond directly if appropriate.")

View File

@@ -1,10 +1,9 @@
import re
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain import BasePromptTemplate
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
@@ -12,6 +11,8 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -49,7 +50,6 @@ Action:
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
model_instance: BaseLLM
dataset_tools: Sequence[BaseTool]
class Config:
@@ -93,12 +93,16 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
rst = tool.run(tool_input={'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
raise new_exception
try:
@@ -108,6 +112,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
elif isinstance(tool_inputs, str):
agent_decision.tool_input = kwargs['input']
else:
agent_decision.return_values['output'] = ''
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
@@ -142,10 +150,65 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def create_completion_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:
tools: List of tools the agent will have access to, used to format the
prompt.
prefix: String to put before the list of tools.
input_variables: List of input variables the final prompt will expect.
Returns:
A PromptTemplate with the template assembled from the pieces here.
"""
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
agent_scratchpad += action.log
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
if not isinstance(agent_scratchpad, str):
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
else:
return agent_scratchpad
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
@@ -157,17 +220,36 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_instance.model_mode == ModelMode.CHAT:
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
else:
prompt = cls.create_completion_prompt(
tools,
prefix=prefix,
format_instructions=format_instructions,
input_variables=input_variables
)
llm_chain = LLMChain(
model_instance=model_instance,
prompt=prompt,
callback_manager=callback_manager,
output_parser=output_parser,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
dataset_tools=tools,
**kwargs,
)

View File

@@ -1,19 +1,21 @@
import re
from typing import List, Tuple, Any, Union, Sequence, Optional
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain import BasePromptTemplate
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
get_buffer_string
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.llm.base import BaseLLM
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -52,8 +54,7 @@ Action:
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel = None
model_instance: BaseLLM
summary_model_instance: BaseLLM = None
class Config:
"""Configuration for this pydantic object."""
@@ -95,14 +96,14 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if prompts:
messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
raise new_exception
try:
@@ -118,7 +119,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_llm:
if len(intermediate_steps) >= 2 and self.summary_model_instance:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
@@ -130,11 +131,10 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
summary_handler = SummarizerMixin(llm=self.summary_llm)
if self.moving_summary_buffer and 'chat_history' in kwargs:
kwargs["chat_history"].pop()
self.moving_summary_buffer = summary_handler.predict_new_summary(
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
@@ -144,6 +144,18 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
@classmethod
def create_prompt(
cls,
@@ -173,10 +185,65 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def create_completion_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:
tools: List of tools the agent will have access to, used to format the
prompt.
prefix: String to put before the list of tools.
input_variables: List of input variables the final prompt will expect.
Returns:
A PromptTemplate with the template assembled from the pieces here.
"""
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
agent_scratchpad += action.log
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
if not isinstance(agent_scratchpad, str):
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
else:
return agent_scratchpad
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
@@ -188,16 +255,35 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_instance.model_mode == ModelMode.CHAT:
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
else:
prompt = cls.create_completion_prompt(
tools,
prefix=prefix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
model_instance=model_instance,
prompt=prompt,
callback_manager=callback_manager,
output_parser=output_parser,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
**kwargs,
)

View File

@@ -10,7 +10,6 @@ from pydantic import BaseModel, Extra
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
@@ -27,7 +26,6 @@ class PlanningStrategy(str, enum.Enum):
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
MULTI_FUNCTION_CALL = 'multi_function_call'
class AgentConfiguration(BaseModel):
@@ -64,30 +62,18 @@ class AgentExecutor:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_model_instance.client
summary_model_instance=self.configuration.summary_model_instance
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used to read chat history from memory
summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used to read chat history from memory
summary_llm=self.configuration.summary_model_instance.client
summary_model_instance=self.configuration.summary_model_instance
if self.configuration.summary_model_instance else None,
verbose=True
)
@@ -95,7 +81,6 @@ class AgentExecutor:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
verbose=True
@@ -104,7 +89,6 @@ class AgentExecutor:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
verbose=True

View File

@@ -1,13 +1,25 @@
import logging
from typing import Any, Dict, List, Union
import threading
import time
from typing import Any, Dict, List, Union, Optional
from flask import Flask, current_app
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage
from pydantic import BaseModel
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.moderation.base import ModerationOutputsResult, ModerationAction
from core.moderation.factory import ModerationFactory
class ModerationRule(BaseModel):
type: str
config: Dict[str, Any]
class LLMCallbackHandler(BaseCallbackHandler):
@@ -20,6 +32,24 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.start_at = None
self.conversation_message_task = conversation_message_task
self.output_moderation_handler = None
self.init_output_moderation()
def init_output_moderation(self):
app_model_config = self.conversation_message_task.app_model_config
sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
self.output_moderation_handler = OutputModerationHandler(
tenant_id=self.conversation_message_task.tenant_id,
app_id=self.conversation_message_task.app.id,
rule=ModerationRule(
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config")
),
on_message_replace_func=self.conversation_message_task.on_message_replace
)
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
@@ -59,10 +89,19 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(response.generations[0][0].text)
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
self.llm_message.completion = self.output_moderation_handler.moderation_completion(
completion=response.generations[0][0].text,
public_event=True if self.conversation_message_task.streaming else False
)
else:
self.llm_message.completion = response.generations[0][0].text
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(self.llm_message.completion)
if response.llm_output and 'token_usage' in response.llm_output:
if 'prompt_tokens' in response.llm_output['token_usage']:
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
@@ -79,23 +118,161 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.conversation_message_task.save_message(self.llm_message)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
try:
self.conversation_message_task.append_message_text(token)
except ConversationTaskStoppedException as ex:
if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
# stop consuming new tokens once output moderation decides to direct the output
ex = ConversationTaskInterruptException()
self.on_llm_error(error=ex)
raise ex
self.llm_message.completion += token
try:
self.conversation_message_task.append_message_text(token)
self.llm_message.completion += token
if self.output_moderation_handler:
self.output_moderation_handler.append_new_token(token)
except ConversationTaskStoppedException as ex:
self.on_llm_error(error=ex)
raise ex
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
if isinstance(error, ConversationTaskStoppedException):
if self.conversation_message_task.streaming:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
if isinstance(error, ConversationTaskInterruptException):
self.llm_message.completion = self.output_moderation_handler.get_final_output()
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message)
else:
logging.debug("on_llm_error: %s", error)
class OutputModerationHandler(BaseModel):
DEFAULT_BUFFER_SIZE: int = 300
tenant_id: str
app_id: str
rule: ModerationRule
on_message_replace_func: Any
thread: Optional[threading.Thread] = None
thread_running: bool = True
buffer: str = ''
is_final_chunk: bool = False
final_output: Optional[str] = None
class Config:
arbitrary_types_allowed = True
def should_direct_output(self):
return self.final_output is not None
def get_final_output(self):
return self.final_output
def append_new_token(self, token: str):
self.buffer += token
if not self.thread:
self.thread = self.start_thread()
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
self.buffer = completion
self.is_final_chunk = True
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=completion
)
if not result or not result.flagged:
return completion
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
else:
final_output = result.text
if public_event:
self.on_message_replace_func(final_output)
return final_output
def start_thread(self) -> threading.Thread:
buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
thread = threading.Thread(target=self.worker, kwargs={
'flask_app': current_app._get_current_object(),
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
})
thread.start()
return thread
def stop_thread(self):
if self.thread and self.thread.is_alive():
self.thread_running = False
def worker(self, flask_app: Flask, buffer_size: int):
with flask_app.app_context():
current_length = 0
while self.thread_running:
moderation_buffer = self.buffer
buffer_length = len(moderation_buffer)
if not self.is_final_chunk:
chunk_length = buffer_length - current_length
if 0 <= chunk_length < buffer_size:
time.sleep(1)
continue
current_length = buffer_length
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=moderation_buffer
)
if not result or not result.flagged:
continue
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
self.final_output = final_output
else:
final_output = result.text + self.buffer[len(moderation_buffer):]
# trigger replace event
if self.thread_running:
self.on_message_replace_func(final_output)
if result.action == ModerationAction.DIRECT_OUTPUT:
break
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
try:
moderation_factory = ModerationFactory(
name=self.rule.type,
app_id=app_id,
tenant_id=tenant_id,
config=self.rule.config
)
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
logging.error("Moderation Output error: %s", e)
return None

View File

@@ -0,0 +1,36 @@
from typing import List, Dict, Any, Optional
from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import LLMResult, Generation
from langchain.schema.language_model import BaseLanguageModel
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain):
model_instance: BaseLLM
"""The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="")
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
messages = prompts[0].to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
stop=stop
)
generations = [
[Generation(text=result.content)]
]
return LLMResult(generations=generations)
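A minimal usage sketch of this wrapper, assuming `model_instance` is a BaseLLM obtained elsewhere (for example via ModelFactory.get_text_generation_model); the template text is illustrative:

from langchain import PromptTemplate
from core.chain.llm_chain import LLMChain

# model_instance: BaseLLM, e.g. ModelFactory.get_text_generation_model(tenant_id=...)  (assumption)
prompt = PromptTemplate(
    template="Summarize the following text:\n{content}",
    input_variables=["content"],
)
chain = LLMChain(model_instance=model_instance, prompt=prompt)
summary = chain.predict(content="LangChain chains wrap prompt templating and model calls.")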

View File

@@ -1,92 +0,0 @@
import enum
import logging
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import BaseModel
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation import openai_moderation
class SensitiveWordAvoidanceRule(BaseModel):
class Type(enum.Enum):
MODERATION = "moderation"
KEYWORDS = "keywords"
type: Type
canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
extra_params: dict = {}
class SensitiveWordAvoidanceChain(Chain):
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
model_instance: BaseLLM
sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
@property
def _chain_type(self) -> str:
return "sensitive_word_avoidance_chain"
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
def _check_sensitive_word(self, text: str) -> bool:
for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
if word in text:
return False
return True
def _check_moderation(self, text: str) -> bool:
moderation_model_instance = ModelFactory.get_moderation_model(
tenant_id=self.model_instance.model_provider.provider.tenant_id,
model_provider_name='openai',
model_name=openai_moderation.DEFAULT_MODEL
)
try:
return moderation_model_instance.run(text=text)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
text = inputs[self.input_key]
if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
result = self._check_sensitive_word(text)
else:
result = self._check_moderation(text)
if not result:
raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
return {self.output_key: text}
class SensitiveWordAvoidanceError(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message

View File

@@ -1,14 +1,18 @@
import concurrent
import json
import logging
from typing import Optional, List, Union
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, Union, Tuple
from flask import current_app, Flask
from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.external_data_tool.factory import ExternalDataToolFactory
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
@@ -16,10 +20,11 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.dataset import DocumentSegment, Dataset, Document
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory
class Completion:
@@ -30,7 +35,7 @@ class Completion:
"""
errors: ProviderTokenNotInitError
"""
query = PromptBuilder.process_template(query)
query = PromptTemplateParser.remove_template_variables(query)
memory = None
if conversation:
@@ -78,26 +83,35 @@ class Completion:
)
try:
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
final_model_instance, [chain_callback])
if sensitive_word_avoidance_chain:
try:
query = sensitive_word_avoidance_chain.run(query)
except SensitiveWordAvoidanceError as ex:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=ex.message
)
return
try:
# process sensitive_word_avoidance
inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
except ModerationException as e:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=str(e)
)
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_model_config.external_data_tools_list
if external_data_tools:
inputs = cls.fill_in_inputs_from_external_data_tools(
tenant_id=app.tenant_id,
app_id=app.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
@@ -137,19 +151,110 @@ class Completion:
memory=memory,
fake_response=fake_response
)
except ConversationTaskStoppedException:
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
return
except ChunkedEncodingError as e:
# Interrupt by LLM (like OpenAI), handle it.
logging.warning(f'ChunkedEncodingError: {e}')
conversation_message_task.end()
return
@classmethod
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
return inputs, query
type = app_model_config.sensitive_word_avoidance_dict['type']
moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
moderation_result = moderation.moderation_for_inputs(inputs, query)
if not moderation_result.flagged:
return inputs, query
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
raise ModerationException(moderation_result.preset_response)
elif moderation_result.action == ModerationAction.OVERRIDED:
inputs = moderation_result.inputs
query = moderation_result.query
return inputs, query
@classmethod
def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
inputs: dict, query: str) -> dict:
"""
Fill in variable inputs from external data tools if exists.
:param tenant_id: workspace id
:param app_id: app id
:param external_data_tools: external data tools configs
:param inputs: the inputs
:param query: the query
:return: the filled inputs
"""
# Group tools by type and config
grouped_tools = {}
for tool in external_data_tools:
if not tool.get("enabled"):
continue
tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
grouped_tools.setdefault(tool_key, []).append(tool)
results = {}
with ThreadPoolExecutor() as executor:
futures = {}
for tools in grouped_tools.values():
# Only query the first tool in each group
first_tool = tools[0]
future = executor.submit(
cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, first_tool,
inputs, query
)
for tool in tools:
futures[future] = tool
for future in concurrent.futures.as_completed(futures):
tool_key, result = future.result()
if tool_key in grouped_tools:
for tool in grouped_tools[tool_key]:
results[tool['variable']] = result
inputs.update(results)
return inputs
@classmethod
def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
with flask_app.app_context():
tool_variable = external_data_tool.get("variable")
tool_type = external_data_tool.get("type")
tool_config = external_data_tool.get("config")
external_data_tool_factory = ExternalDataToolFactory(
name=tool_type,
tenant_id=tenant_id,
app_id=app_id,
variable=tool_variable,
config=tool_config
)
# query external data tool
result = external_data_tool_factory.query(
inputs=inputs,
query=query
)
tool_key = (external_data_tool.get("type"), json.dumps(external_data_tool.get("config"), sort_keys=True))
return tool_key, result
@classmethod
def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
if app.mode != 'completion':
return query
return inputs.get(app_model_config.dataset_query_variable, "")
@classmethod
@@ -159,15 +264,33 @@ class Completion:
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
prompt_transform = PromptTransform()
# get llm prompt
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
if app_model_config.prompt_type == 'simple':
prompt_messages, stop_words = prompt_transform.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory,
model_instance=model_instance
)
else:
prompt_messages = prompt_transform.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory,
model_instance=model_instance
)
model_config = app_model_config.model_dict
completion_params = model_config.get("completion_params", {})
stop_words = completion_params.get("stop", [])
cls.recale_llm_max_tokens(
model_instance=model_instance,
@@ -176,7 +299,7 @@ class Completion:
response = model_instance.run(
messages=prompt_messages,
stop=stop_words,
stop=stop_words if stop_words else None,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
@@ -227,15 +350,30 @@ class Completion:
if max_tokens is None:
max_tokens = 0
prompt_transform = PromptTransform()
prompt_messages = []
# get prompt without memory and context
prompt_messages, _ = model_instance.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=None,
memory=None
)
if app_model_config.prompt_type == 'simple':
prompt_messages, _ = prompt_transform.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=None,
memory=None,
model_instance=model_instance
)
else:
prompt_messages = prompt_transform.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
context=None,
memory=None,
model_instance=model_instance
)
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
@@ -266,52 +404,3 @@ class Completion:
model_kwargs = model_instance.get_model_kwargs()
model_kwargs.max_tokens = max_tokens
model_instance.set_model_kwargs(model_kwargs)
@classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool):
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=app.tenant_id,
model_config=app_model_config.model_dict,
streaming=streaming
)
# get llm prompt
old_prompt_messages, _ = final_model_instance.get_prompt(
mode='completion',
pre_prompt=pre_prompt,
inputs=message.inputs,
query=message.query,
context=None,
memory=None
)
original_completion = message.answer.strip()
prompt = MORE_LIKE_THIS_GENERATE_PROMPT
prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
prompt_messages = [PromptMessage(content=prompt)]
conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
app_model_config=app_model_config,
user=user,
inputs=message.inputs,
query=message.query,
is_override=True if message.override_model_configs else False,
streaming=streaming,
model_instance=final_model_instance
)
cls.recale_llm_max_tokens(
model_instance=final_model_instance,
prompt_messages=prompt_messages
)
final_model_instance.run(
messages=prompt_messages,
callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
)

View File

@@ -10,7 +10,7 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -74,10 +74,10 @@ class ConversationMessageTask:
if self.mode == 'chat':
introduction = self.app_model_config.opening_statement
if introduction:
prompt_template = JinjaPromptTemplate.from_template(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
prompt_template = PromptTemplateParser(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
try:
introduction = prompt_template.format(**prompt_inputs)
introduction = prompt_template.format(prompt_inputs)
except KeyError:
pass
@@ -150,12 +150,12 @@ class ConversationMessageTask:
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price
@@ -163,7 +163,7 @@ class ConversationMessageTask:
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
self.message.answer = PromptBuilder.process_template(
self.message.answer = PromptTemplateParser.remove_template_variables(
llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
@@ -226,15 +226,15 @@ class ConversationMessageTask:
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price
@@ -290,6 +290,10 @@ class ConversationMessageTask:
db.session.commit()
self.retriever_resource = resource
def on_message_replace(self, text: str):
if text is not None:
self._pub_handler.pub_message_replace(text)
def message_end(self):
self._pub_handler.pub_message_end(self.retriever_resource)
@@ -342,6 +346,24 @@ class PubHandler:
self.pub_end()
raise ConversationTaskStoppedException()
def pub_message_replace(self, text: str):
content = {
'event': 'message_replace',
'data': {
'task_id': self._task_id,
'message_id': str(self._message.id),
'text': text,
'mode': self._conversation.mode,
'conversation_id': str(self._conversation.id)
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_chain(self, message_chain: MessageChain):
if self._chain_pub:
content = {
@@ -443,3 +465,7 @@ class PubHandler:
class ConversationTaskStoppedException(Exception):
pass
class ConversationTaskInterruptException(Exception):
pass

View File

@@ -0,0 +1,62 @@
import os
import requests
from models.api_based_extension import APIBasedExtensionPoint
class APIBasedExtensionRequestor:
timeout: (int, int) = (5, 60)
"""timeout for request connect and read"""
def __init__(self, api_endpoint: str, api_key: str) -> None:
self.api_endpoint = api_endpoint
self.api_key = api_key
def request(self, point: APIBasedExtensionPoint, params: dict) -> dict:
"""
Request the api.
:param point: the api point
:param params: the request params
:return: the response json
"""
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(self.api_key)
}
url = self.api_endpoint
try:
# proxy support for security
proxies = None
if os.environ.get("API_BASED_EXTENSION_HTTP_PROXY") and os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"):
proxies = {
'http': os.environ.get("API_BASED_EXTENSION_HTTP_PROXY"),
'https': os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"),
}
response = requests.request(
method='POST',
url=url,
json={
'point': point.value,
'params': params
},
headers=headers,
timeout=self.timeout,
proxies=proxies
)
except requests.exceptions.Timeout:
raise ValueError("request timeout")
except requests.exceptions.ConnectionError:
raise ValueError("request connection error")
if response.status_code != 200:
raise ValueError("request error, status_code: {}, content: {}".format(
response.status_code,
response.text[:100]
))
return response.json()
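A hedged usage sketch of the requestor; the endpoint, key and params are placeholders:

from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from models.api_based_extension import APIBasedExtensionPoint

requestor = APIBasedExtensionRequestor(
    api_endpoint="https://example.com/dify-extension",  # placeholder endpoint
    api_key="your-api-key",                             # placeholder key
)
response_json = requestor.request(
    point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
    params={"app_id": "app-id", "tool_variable": "weather", "inputs": {}, "query": "Hello"},
)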

View File

@@ -0,0 +1,111 @@
import enum
import importlib.util
import json
import logging
import os
from collections import OrderedDict
from typing import Any, Optional
from pydantic import BaseModel
class ExtensionModule(enum.Enum):
MODERATION = 'moderation'
EXTERNAL_DATA_TOOL = 'external_data_tool'
class ModuleExtension(BaseModel):
extension_class: Any
name: str
label: Optional[dict] = None
form_schema: Optional[list] = None
builtin: bool = True
position: Optional[int] = None
class Extensible:
module: ExtensionModule
name: str
tenant_id: str
config: Optional[dict] = None
def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
self.tenant_id = tenant_id
self.config = config
@classmethod
def scan_extensions(cls):
extensions = {}
# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
current_dir_path = os.path.dirname(current_path)
# traverse subdirectories
for subdir_name in os.listdir(current_dir_path):
if subdir_name.startswith('__'):
continue
subdir_path = os.path.join(current_dir_path, subdir_name)
extension_name = subdir_name
if os.path.isdir(subdir_path):
file_names = os.listdir(subdir_path)
# builtin extensions receive special treatment in the front-end page and business logic.
builtin = False
position = None
if '__builtin__' in file_names:
builtin = True
builtin_file_path = os.path.join(subdir_path, '__builtin__')
if os.path.exists(builtin_file_path):
with open(builtin_file_path, 'r') as f:
position = int(f.read().strip())
if (extension_name + '.py') not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue
# Dynamically load {subdir_name}.py and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + '.py')
spec = importlib.util.spec_from_file_location(extension_name, py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
extension_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
extension_class = obj
break
if not extension_class:
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
continue
json_data = {}
if not builtin:
if 'schema.json' not in file_names:
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
continue
json_path = os.path.join(subdir_path, 'schema.json')
json_data = {}
if os.path.exists(json_path):
with open(json_path, 'r') as f:
json_data = json.load(f)
extensions[extension_name] = ModuleExtension(
extension_class=extension_class,
name=extension_name,
label=json_data.get('label'),
form_schema=json_data.get('form_schema'),
builtin=builtin,
position=position
)
sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
sorted_extensions = OrderedDict(sorted_items)
return sorted_extensions

View File

@@ -0,0 +1,47 @@
from core.extension.extensible import ModuleExtension, ExtensionModule
from core.external_data_tool.base import ExternalDataTool
from core.moderation.base import Moderation
class Extension:
__module_extensions: dict[str, dict[str, ModuleExtension]] = {}
module_classes = {
ExtensionModule.MODERATION: Moderation,
ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool
}
def init(self):
for module, module_class in self.module_classes.items():
self.__module_extensions[module.value] = module_class.scan_extensions()
def module_extensions(self, module: str) -> list[ModuleExtension]:
module_extensions = self.__module_extensions.get(module)
if not module_extensions:
raise ValueError(f"Extension Module {module} not found")
return list(module_extensions.values())
def module_extension(self, module: ExtensionModule, extension_name: str) -> ModuleExtension:
module_extensions = self.__module_extensions.get(module.value)
if not module_extensions:
raise ValueError(f"Extension Module {module} not found")
module_extension = module_extensions.get(extension_name)
if not module_extension:
raise ValueError(f"Extension {extension_name} not found")
return module_extension
def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
module_extension = self.module_extension(module, extension_name)
return module_extension.extension_class
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
module_extension = self.module_extension(module, extension_name)
form_schema = module_extension.form_schema
# TODO validate form_schema
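A hedged lookup sketch using the shared registry; it assumes extensions.ext_code_based_extension exposes an already-initialized `code_based_extension` instance (as imported elsewhere in this change), and the ids are placeholders:

from core.extension.extensible import ExtensionModule
from extensions.ext_code_based_extension import code_based_extension

tool_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, "api")
tool = tool_class(
    tenant_id="tenant-id",                               # placeholder
    app_id="app-id",                                     # placeholder
    variable="weather",                                  # placeholder variable name
    config={"api_based_extension_id": "extension-id"},   # placeholder extension id
)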

View File

@@ -0,0 +1 @@
1

View File

@@ -0,0 +1,92 @@
from typing import Optional
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.external_data_tool.base import ExternalDataTool
from core.helper import encrypter
from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
class ApiExternalDataTool(ExternalDataTool):
"""
The api external data tool.
"""
name: str = "api"
"""the unique name of external data tool"""
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
# own validation logic
api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id:
raise ValueError("api_based_extension_id is required")
# get api_based_extension
api_based_extension = db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == tenant_id,
APIBasedExtension.id == api_based_extension_id
).first()
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
def query(self, inputs: dict, query: Optional[str] = None) -> str:
"""
Query the external data tool.
:param inputs: user inputs
:param query: the query of chat app
:return: the tool query result
"""
# get params from config
api_based_extension_id = self.config.get("api_based_extension_id")
# get api_based_extension
api_based_extension = db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == self.tenant_id,
APIBasedExtension.id == api_based_extension_id
).first()
if not api_based_extension:
raise ValueError("[External data tool] API query failed, variable: {}, "
"error: api_based_extension_id is invalid"
.format(self.config.get('variable')))
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=self.tenant_id,
token=api_based_extension.api_key
)
try:
# request api
requestor = APIBasedExtensionRequestor(
api_endpoint=api_based_extension.api_endpoint,
api_key=api_key
)
except Exception as e:
raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
self.config.get('variable'),
e
))
response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={
'app_id': self.app_id,
'tool_variable': self.variable,
'inputs': inputs,
'query': query
})
if 'result' not in response_json:
raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
.format(self.config.get('variable')))
return response_json['result']

View File

@@ -0,0 +1,45 @@
from abc import abstractmethod, ABC
from typing import Optional
from core.extension.extensible import Extensible, ExtensionModule
class ExternalDataTool(Extensible, ABC):
"""
The base class of external data tool.
"""
module: ExtensionModule = ExtensionModule.EXTERNAL_DATA_TOOL
app_id: str
"""the id of app"""
variable: str
"""the tool variable name of app tool"""
def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None:
super().__init__(tenant_id, config)
self.app_id = app_id
self.variable = variable
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
raise NotImplementedError
@abstractmethod
def query(self, inputs: dict, query: Optional[str] = None) -> str:
"""
Query the external data tool.
:param inputs: user inputs
:param query: the query of chat app
:return: the tool query result
"""
raise NotImplementedError
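As a hedged illustration of what scan_extensions() discovers, a hypothetical custom tool could live in external_data_tool/weather_search/weather_search.py next to a schema.json; all names below are illustrative, not part of this change:

from typing import Optional

from core.external_data_tool.base import ExternalDataTool


class WeatherSearch(ExternalDataTool):
    """A hypothetical external data tool picked up by scan_extensions()."""
    name: str = "weather_search"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        # reject configs missing the (hypothetical) api_key field
        if not config.get("api_key"):
            raise ValueError("api_key is required")

    def query(self, inputs: dict, query: Optional[str] = None) -> str:
        # the returned string fills the app variable bound to this tool
        return "Sunny, 24°C"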

View File

@@ -0,0 +1,40 @@
from typing import Optional
from core.extension.extensible import ExtensionModule
from extensions.ext_code_based_extension import code_based_extension
class ExternalDataToolFactory:
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
self.__extension_instance = extension_class(
tenant_id=tenant_id,
app_id=app_id,
variable=variable,
config=config
)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param name: the name of external data tool
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
extension_class.validate_config(tenant_id, config)
def query(self, inputs: dict, query: Optional[str] = None) -> str:
"""
Query the external data tool.
:param inputs: user inputs
:param query: the query of chat app
:return: the tool query result
"""
return self.__extension_instance.query(inputs, query)
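A hedged end-to-end sketch with the builtin "api" tool; tenant, app and extension ids are placeholders:

from core.external_data_tool.factory import ExternalDataToolFactory

factory = ExternalDataToolFactory(
    name="api",
    tenant_id="tenant-id",                               # placeholder
    app_id="app-id",                                     # placeholder
    variable="weather",                                  # placeholder variable name
    config={"api_based_extension_id": "extension-id"},   # placeholder extension id
)
result = factory.query(inputs={"city": "Berlin"}, query="What should I wear today?")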

View File

@@ -10,9 +10,8 @@ from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
GENERATOR_QA_PROMPT
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
class LLMGenerator:
@@ -44,78 +43,19 @@ class LLMGenerator:
return answer.strip()
@classmethod
def generate_conversation_summary(cls, tenant_id: str, messages):
max_tokens = 200
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=max_tokens
)
)
prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
max_context_token_length = model_instance.model_rules.max_tokens.max
max_context_token_length = max_context_token_length if max_context_token_length else 1500
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
context = ''
for message in messages:
if not message.answer:
continue
if len(message.query) > 2000:
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
else:
query = message.query
if len(message.answer) > 2000:
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
else:
answer = message.answer
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
context += message_qa_text
if not context:
return '[message too long, no summary]'
prompt = prompt.format(context=context)
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
def generate_introduction(cls, tenant_id: str, pre_prompt: str):
prompt = INTRODUCTION_GENERATE_PROMPT
prompt = prompt.format(prompt=pre_prompt)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions()
prompt = JinjaPromptTemplate(
template="{{histories}}\n{{format_instructions}}\nquestions:\n",
input_variables=["histories"],
partial_variables={"format_instructions": format_instructions}
prompt_template = PromptTemplateParser(
template="{{histories}}\n{{format_instructions}}\nquestions:\n"
)
_input = prompt.format_prompt(histories=histories)
prompt = prompt_template.format({
"histories": histories,
"format_instructions": format_instructions
})
try:
model_instance = ModelFactory.get_text_generation_model(
@@ -128,10 +68,10 @@ class LLMGenerator:
except ProviderTokenNotInitError:
return []
prompts = [PromptMessage(content=_input.to_string())]
prompt_messages = [PromptMessage(content=prompt)]
try:
output = model_instance.run(prompts)
output = model_instance.run(prompt_messages)
questions = output_parser.parse(output.content)
except LLMError:
questions = []
@@ -145,19 +85,21 @@ class LLMGenerator:
def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
output_parser = RuleConfigGeneratorOutputParser()
prompt = OutLinePromptTemplate(
template=output_parser.get_format_instructions(),
input_variables=["audiences", "hoping_to_solve"],
partial_variables={
"variable": '{variable}',
"lanA": '{lanA}',
"lanB": '{lanB}',
"topic": '{topic}'
},
validate_template=False
prompt_template = PromptTemplateParser(
template=output_parser.get_format_instructions()
)
_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
prompt = prompt_template.format(
inputs={
"audiences": audiences,
"hoping_to_solve": hoping_to_solve,
"variable": "{{variable}}",
"lanA": "{{lanA}}",
"lanB": "{{lanB}}",
"topic": "{{topic}}"
},
remove_template_variables=False
)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
@@ -167,10 +109,10 @@ class LLMGenerator:
)
)
prompts = [PromptMessage(content=_input.to_string())]
prompt_messages = [PromptMessage(content=prompt)]
try:
output = model_instance.run(prompts)
output = model_instance.run(prompt_messages)
rule_config = output_parser.parse(output.content)
except LLMError as e:
raise e
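A minimal sketch of the PromptTemplateParser API these call sites rely on; the template text is illustrative and the rendered output reflects the expected behaviour rather than a verified result:

from core.prompt.prompt_template import PromptTemplateParser

parser = PromptTemplateParser(template="Hello {{name}}, welcome to {{product}}!")
keys = parser.variable_keys                               # expected: ['name', 'product']
text = parser.format({"name": "Ada", "product": "Dify"})  # expected: "Hello Ada, welcome to Dify!"
# remove_template_variables=False (as used above) appears to keep literal {{...}}
# sequences in the substituted values rather than stripping them — an assumption
# based on the generate_rule_config call site.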

View File

@@ -1,4 +1,5 @@
import logging
import random
import openai
@@ -16,19 +17,20 @@ def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
length = 2000
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
max_text_chunks = 32
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
if len(text_chunks) == 0:
return True
for text_chunk in chunks:
try:
moderation_result = openai.Moderation.create(input=text_chunk,
api_key=hosted_model_providers.openai.api_key)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
text_chunk = random.choice(text_chunks)
for result in moderation_result.results:
if result['flagged'] is True:
return False
try:
moderation_result = openai.Moderation.create(input=text_chunk,
api_key=hosted_model_providers.openai.api_key)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
for result in moderation_result.results:
if result['flagged'] is True:
return False
return True
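A standalone illustration (not Dify code) of the sampling change above: instead of moderating every 2,000-character chunk, one randomly chosen chunk is checked per call:

import random

text = "x" * 5000                                                         # a 5,000-character completion
length = 2000
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]  # 3 chunks
text_chunk = random.choice(text_chunks)                                  # only this chunk is sent to the moderation API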

View File

@@ -167,8 +167,6 @@ class Milvus(VectorStore):
self._init()
@property
def embeddings(self) -> Embeddings:
return self.embedding_func

View File

@@ -11,6 +11,7 @@ from flask import current_app, Flask
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError
from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
@@ -79,6 +80,8 @@ class IndexingRunner:
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except ObjectDeletedError:
logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
@@ -276,13 +279,14 @@ class IndexingRunner:
)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -372,13 +376,14 @@ class IndexingRunner:
)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -582,7 +587,6 @@ class IndexingRunner:
all_qa_documents.extend(format_documents)
def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> List[Document]:
"""
@@ -734,6 +738,9 @@ class IndexingRunner:
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedException()
document = DatasetDocument.query.filter_by(id=document_id).first()
if not document:
raise DocumentIsDeletedPausedException()
update_params = {
DatasetDocument.indexing_status: after_indexing_status
@@ -781,3 +788,7 @@ class IndexingRunner:
class DocumentIsPausedException(Exception):
pass
class DocumentIsDeletedPausedException(Exception):
pass

View File

@@ -31,7 +31,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
chat_messages: List[PromptMessage] = []
for message in messages:
chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
if not chat_messages:

View File

@@ -211,6 +211,9 @@ class ModelProviderFactory:
Provider.quota_type == ProviderQuotaType.TRIAL.value
).first()
if provider.quota_limit == 0:
return None
return provider
no_system_provider = True

View File

@@ -1,8 +1,7 @@
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
class XinferenceEmbedding(BaseEmbedding):

View File

@@ -1,6 +1,6 @@
import enum
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
from pydantic import BaseModel
@@ -9,26 +9,31 @@ class LLMRunResult(BaseModel):
prompt_tokens: int
completion_tokens: int
source: list = None
function_call: dict = None
class MessageType(enum.Enum):
HUMAN = 'human'
USER = 'user'
ASSISTANT = 'assistant'
SYSTEM = 'system'
class PromptMessage(BaseModel):
type: MessageType = MessageType.HUMAN
type: MessageType = MessageType.USER
content: str = ''
function_call: dict = None
def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
if message.type == MessageType.USER:
lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
lc_messages.append(AIMessage(content=message.content))
additional_kwargs = {}
if message.function_call:
additional_kwargs['function_call'] = message.function_call
lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
elif message.type == MessageType.SYSTEM:
lc_messages.append(SystemMessage(content=message.content))
@@ -39,11 +44,21 @@ def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
elif isinstance(message, AIMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
message_kwargs = {
'content': message.content,
'type': MessageType.ASSISTANT
}
if 'function_call' in message.additional_kwargs:
message_kwargs['function_call'] = message.additional_kwargs['function_call']
prompt_messages.append(PromptMessage(**message_kwargs))
elif isinstance(message, SystemMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
elif isinstance(message, FunctionMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
return prompt_messages
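A minimal editor's sketch (not part of the diff) of the renamed message types and the new function_call plumbing shown above; the import path is assumed from the class and helper names in this hunk.

# Editor's sketch: round-tripping a function-call message through the helpers above.
from core.model_providers.models.entity.message import (
    PromptMessage, MessageType, to_lc_messages, to_prompt_messages,
)

messages = [
    PromptMessage(content="What's the weather in Berlin?", type=MessageType.USER),
    PromptMessage(
        content="",
        type=MessageType.ASSISTANT,
        # carried into AIMessage.additional_kwargs by to_lc_messages
        function_call={"name": "get_weather", "arguments": '{"city": "Berlin"}'},
    ),
]

lc_messages = to_lc_messages(messages)      # [HumanMessage, AIMessage]
restored = to_prompt_messages(lc_messages)  # function_call survives the round trip
assert restored[1].function_call["name"] == "get_weather"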

View File

@@ -81,7 +81,20 @@ class AzureOpenAIModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
generate_kwargs = {
'stop': stop,
'callbacks': callbacks
}
if isinstance(prompts, str):
generate_kwargs['prompts'] = [prompts]
else:
generate_kwargs['messages'] = [prompts]
if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']
return self._client.generate(**generate_kwargs)
@property
def base_model_name(self) -> str:

View File
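An aside on the AzureOpenAIModel hunk above: the generate call now routes a plain string prompt to prompts= and a message list to messages=, and forwards function schemas when the caller passes them. A hedged sketch of that dispatch, assuming the client's generate() accepts these keyword arguments as the hunk implies:

# Editor's sketch of the kwargs dispatch (assumed client keywords: prompts/messages/functions).
def build_generate_kwargs(prompts, stop=None, callbacks=None, **kwargs):
    generate_kwargs = {'stop': stop, 'callbacks': callbacks}
    if isinstance(prompts, str):
        generate_kwargs['prompts'] = [prompts]    # completion-mode models take a plain string
    else:
        generate_kwargs['messages'] = [prompts]   # chat-mode models take a list of BaseMessage
    if 'functions' in kwargs:
        generate_kwargs['functions'] = kwargs['functions']  # pass function-call schemas through
    return generate_kwargs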

@@ -37,12 +37,6 @@ class BaichuanModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.

View File

@@ -1,27 +1,18 @@
import json
import os
import re
import time
from abc import abstractmethod
from typing import List, Optional, Any, Union, Tuple
from typing import List, Optional, Any, Union
import decimal
import logging
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from langchain.schema import LLMResult, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_lc_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.third_party.langchain.llms.fake import FakeLLM
import logging
from extensions.ext_database import db
logger = logging.getLogger(__name__)
@@ -157,8 +148,11 @@ class BaseLLM(BaseProviderModel):
except Exception as ex:
raise self.handle_exceptions(ex)
function_call = None
if isinstance(result.generations[0][0], ChatGeneration):
completion_content = result.generations[0][0].message.content
if 'function_call' in result.generations[0][0].message.additional_kwargs:
function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
else:
completion_content = result.generations[0][0].text
@@ -191,7 +185,8 @@ class BaseLLM(BaseProviderModel):
return LLMRunResult(
content=completion_content,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens
completion_tokens=completion_tokens,
function_call=function_call
)
@abstractmethod
@@ -227,7 +222,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return:
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -245,7 +240,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.0001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -260,7 +255,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.000001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
price_unit = self.price_config['unit']
else:
price_unit = self.price_config['unit']
@@ -315,121 +310,8 @@ class BaseLLM(BaseProviderModel):
def support_streaming(self):
return False
def get_prompt(self, mode: str,
pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> \
Tuple[List[PromptMessage], Optional[List[str]]]:
prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'common_completion'
else:
return 'common_chat'
def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
context=context
)
pre_prompt_content = ''
if pre_prompt:
prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
pre_prompt_content = prompt_template.format(
**prompt_inputs
)
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += pre_prompt_content
query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
if memory and 'histories_prompt' in prompt_rules:
# append chat histories
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=prompt + query_prompt,
inputs={
'query': query
}
)
if self.model_rules.max_tokens.max:
curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
max_tokens = self.model_kwargs.max_tokens
rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
else:
rest_tokens = 2000
memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format(
histories=histories
)
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
elif order == 'histories_prompt':
prompt += histories_prompt_content
prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
query_prompt_content = prompt_template.format(
query=query
)
prompt += query_prompt_content
prompt = re.sub(r'<\|.*?\|>', '', prompt)
stops = prompt_rules.get('stops')
if stops is not None and len(stops) == 0:
stops = None
return prompt, stops
def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
# Get the absolute path of the subdirectory
prompt_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
'prompt/generate_prompts')
json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
# Open the JSON file and read its content
with open(json_file_path, 'r') as json_file:
return json.load(json_file)
def _get_history_messages_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> str:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
return external_context[memory_key]
def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]:
if not model_mode:
model_mode = self.model_mode
@@ -442,16 +324,7 @@ class BaseLLM(BaseProviderModel):
if len(messages) == 0:
return []
chat_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
chat_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
chat_messages.append(AIMessage(content=message.content))
elif message.type == MessageType.SYSTEM:
chat_messages.append(SystemMessage(content=message.content))
return chat_messages
return to_lc_messages(messages)
def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
"""

View File

@@ -66,15 +66,6 @@ class HuggingfaceHubModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return self._client.get_num_tokens(prompts)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.model_kwargs = provider_model_kwargs

View File

@@ -1,26 +1,23 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import Minimax
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
class MinimaxModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return Minimax(
return MinimaxChatLLM(
model=self.name,
model_kwargs={
'stream': False
},
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
@@ -49,7 +46,7 @@ class MinimaxModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def get_currency(self):
return 'RMB'
@@ -65,3 +62,7 @@ class MinimaxModel(BaseLLM):
return LLMBadRequestError(f"Minimax: {str(ex)}")
else:
return ex
@property
def support_streaming(self):
return True

View File

@@ -33,7 +33,7 @@ MODEL_MAX_TOKENS = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
@@ -106,7 +106,21 @@ class OpenAIModel(BaseLLM):
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
generate_kwargs = {
'stop': stop,
'callbacks': callbacks
}
if isinstance(prompts, str):
generate_kwargs['prompts'] = [prompts]
else:
generate_kwargs['messages'] = [prompts]
if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']
return self._client.generate(**generate_kwargs)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""

View File

@@ -49,15 +49,6 @@ class OpenLLMModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass

View File

@@ -18,7 +18,6 @@ class TongyiModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
del provider_model_kwargs['max_tokens']
return EnhanceTongyi(
model_name=self.name,
max_retries=1,
@@ -58,7 +57,6 @@ class TongyiModel(BaseLLM):
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
del provider_model_kwargs['max_tokens']
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)

View File

@@ -6,17 +6,16 @@ from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.wenxin import Wenxin
class WenxinModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
# TODO load price_config from configs(db)
return Wenxin(
model=self.name,
streaming=self.streaming,
@@ -38,7 +37,13 @@ class WenxinModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']
return self._client.generate(**generate_kwargs)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
@@ -48,7 +53,7 @@ class WenxinModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
@@ -58,3 +63,7 @@ class WenxinModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}")
@property
def support_streaming(self):
return True

View File

@@ -59,15 +59,6 @@ class XinferenceModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass

View File

@@ -16,6 +16,7 @@ class ZhipuAIModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ZhipuAIChatLLM(
model=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,

View File

@@ -9,7 +9,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType
@@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider):
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
'mode': ModelMode.CHAT.value,
},
{
'id': 'claude-2',
'name': 'claude-2',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -167,7 +172,7 @@ class AnthropicProvider(BaseModelProvider):
def should_deduct_quota(self):
if hosted_model_providers.anthropic and \
hosted_model_providers.anthropic.quota_limit and hosted_model_providers.anthropic.quota_limit > 0:
hosted_model_providers.anthropic.quota_limit and hosted_model_providers.anthropic.quota_limit > -1:
return True
return False

View File

@@ -12,7 +12,7 @@ from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
AZURE_OPENAI_API_VERSION
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider):
}
credentials = json.loads(provider_model.encrypted_config)
if provider_model.model_type == ModelType.TEXT_GENERATION.value:
model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name'])
if credentials['base_model_name'] in [
'gpt-4',
'gpt-4-32k',
@@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider):
return model_list
def _get_text_generation_model_mode(self, model_name) -> str:
if model_name == 'text-davinci-003':
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
models = [
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
}
]
@@ -314,7 +329,7 @@ class AzureOpenAIProvider(BaseModelProvider):
def should_deduct_quota(self):
if hosted_model_providers.azure_openai \
and hosted_model_providers.azure_openai.quota_limit and hosted_model_providers.azure_openai.quota_limit > 0:
and hosted_model_providers.azure_openai.quota_limit and hosted_model_providers.azure_openai.quota_limit > -1:
return True
return False

View File

@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
@@ -21,6 +21,9 @@ class BaichuanProvider(BaseModelProvider):
Returns the name of a provider.
"""
return 'baichuan'
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
@@ -28,6 +31,7 @@ class BaichuanProvider(BaseModelProvider):
{
'id': 'baichuan2-53b',
'name': 'Baichuan2-53B',
'mode': ModelMode.CHAT.value,
}
]
else:

View File

@@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC):
ProviderModel.is_valid == True
).order_by(ProviderModel.created_at.asc()).all()
return [{
'id': provider_model.model_name,
'name': provider_model.model_name
} for provider_model in provider_models]
provider_model_list = []
for provider_model in provider_models:
provider_model_dict = {
'id': provider_model.model_name,
'name': provider_model.model_name
}
if model_type == ModelType.TEXT_GENERATION:
provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)
provider_model_list.append(provider_model_dict)
return provider_model_list
@abstractmethod
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
@@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC):
"""
raise NotImplementedError
@abstractmethod
def _get_text_generation_model_mode(self, model_name) -> str:
"""
get text generation model mode.
:param model_name:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_model_class(self, model_type: ModelType) -> Type:
"""

View File
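An aside on the BaseModelProvider hunk above: providers must now report a per-model text generation mode, and fixed model lists carry a 'mode' field. A hedged sketch of what a hypothetical provider supplies under the new contract (other abstract methods omitted):

# Editor's sketch with a made-up provider; method names follow the hunk above.
from core.model_providers.models.entity.model_params import ModelType, ModelMode
from core.model_providers.providers.base import BaseModelProvider

class ExampleProvider(BaseModelProvider):
    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [{'id': 'example-chat', 'name': 'example-chat', 'mode': ModelMode.CHAT.value}]
        return []

    def _get_text_generation_model_mode(self, model_name) -> str:
        # this hypothetical provider only exposes chat-style models
        return ModelMode.CHAT.value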

@@ -6,7 +6,7 @@ from langchain.llms import ChatGLM
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.chatglm_model import ChatGLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType
@@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider):
{
'id': 'chatglm2-6b',
'name': 'ChatGLM2-6B',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'chatglm-6b',
'name': 'ChatGLM-6B',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -11,7 +11,7 @@ class HostedOpenAI(BaseModel):
api_organization: str = None
api_key: str
quota_limit: int = 0
"""Quota limit for the openai hosted model. 0 means unlimited."""
"""Quota limit for the openai hosted model. -1 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1
@@ -21,14 +21,14 @@ class HostedAzureOpenAI(BaseModel):
api_base: str
api_key: str
quota_limit: int = 0
"""Quota limit for the azure openai hosted model. 0 means unlimited."""
"""Quota limit for the azure openai hosted model. -1 means unlimited."""
class HostedAnthropic(BaseModel):
api_base: str = None
api_key: str
quota_limit: int = 0
"""Quota limit for the anthropic hosted model. 0 means unlimited."""
"""Quota limit for the anthropic hosted model. -1 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1000000
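A short reading of the quota change above, inferred from the docstring edits here and the quota_limit == 0 check added to ModelProviderFactory earlier in this diff (not stated outright anywhere): -1 now marks an unlimited hosted quota, 0 means no trial quota is granted, and a positive value is a hard cap.

# Editor's sketch of the implied semantics (not part of the diff).
def describe_quota(quota_limit: int) -> str:
    if quota_limit == -1:
        return "unlimited"
    if quota_limit == 0:
        return "no trial quota"        # the trial provider is skipped in this case
    return f"capped at {quota_limit} tokens"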

View File

@@ -5,7 +5,7 @@ import requests
from huggingface_hub import HfApi
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -29,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule, ModelMode
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
credentials = self.get_model_credentials(model_name, ModelType.TEXT_GENERATION)
if credentials['completion_type'] == 'chat_completion':
return ModelMode.CHAT.value
else:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -2,14 +2,15 @@ import json
from json import JSONDecodeError
from typing import Type
from langchain.llms import Minimax
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
from models.provider import ProviderType, ProviderQuotaType
@@ -28,10 +29,12 @@ class MinimaxProvider(BaseModelProvider):
{
'id': 'abab5.5-chat',
'name': 'abab5.5-chat',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'abab5-chat',
'name': 'abab5-chat',
'mode': ModelMode.COMPLETION.value,
}
]
elif model_type == ModelType.EMBEDDINGS:
@@ -44,6 +47,9 @@ class MinimaxProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -98,14 +104,14 @@ class MinimaxProvider(BaseModelProvider):
'minimax_api_key': credentials['minimax_api_key'],
}
llm = Minimax(
llm = MinimaxChatLLM(
model='abab5.5-chat',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
llm("ping")
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

View File

@@ -13,8 +13,8 @@ from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.openai_model import OpenAIModel, COMPLETION_MODELS
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
@@ -36,6 +36,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -43,10 +44,12 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -54,6 +57,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -61,6 +65,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -68,6 +73,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
}
]
@@ -100,6 +106,12 @@ class OpenAIProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
if model_name in COMPLETION_MODELS:
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -132,7 +144,7 @@ class OpenAIProvider(BaseModelProvider):
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
@@ -238,7 +250,7 @@ class OpenAIProvider(BaseModelProvider):
def should_deduct_quota(self):
if hosted_model_providers.openai \
and hosted_model_providers.openai.quota_limit and hosted_model_providers.openai.quota_limit > 0:
and hosted_model_providers.openai.quota_limit and hosted_model_providers.openai.quota_limit > -1:
return True
return False

View File

@@ -3,7 +3,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -24,6 +24,9 @@ class OpenLLMProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -6,7 +6,8 @@ import replicate
from replicate.exceptions import ReplicateError
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \
ModelMode
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -26,6 +27,9 @@ class ReplicateProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value if model_name.endswith('-chat') else ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.spark import ChatSpark
@@ -28,17 +28,27 @@ class SparkProvider(BaseModelProvider):
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'spark',
'name': 'Spark V1.5',
'id': 'spark-v3',
'name': 'Spark V3.0',
'mode': ModelMode.CHAT.value,
},
{
'id': 'spark-v2',
'name': 'Spark V2.0',
'mode': ModelMode.CHAT.value,
},
{
'id': 'spark',
'name': 'Spark V1.5',
'mode': ModelMode.CHAT.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -91,7 +101,7 @@ class SparkProvider(BaseModelProvider):
try:
chat_llm = ChatSpark(
model_name='spark-v2',
model_name='spark-v3',
max_tokens=10,
temperature=0.01,
**credential_kwargs
@@ -105,10 +115,10 @@ class SparkProvider(BaseModelProvider):
chat_llm(messages)
except SparkError as ex:
# try spark v1.5 if v2.1 failed
# try spark v2.1 if v3.1 failed
try:
chat_llm = ChatSpark(
model_name='spark',
model_name='spark-v2',
max_tokens=10,
temperature=0.01,
**credential_kwargs
@@ -122,10 +132,27 @@ class SparkProvider(BaseModelProvider):
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
# try spark v1.5 if v2.1 failed
try:
chat_llm = ChatSpark(
model_name='spark',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex

View File

@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.tongyi_model import TongyiModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
@@ -24,17 +24,22 @@ class TongyiProvider(BaseModelProvider):
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'qwen-v1',
'name': 'qwen-v1',
'id': 'qwen-turbo',
'name': 'qwen-turbo',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'qwen-plus-v1',
'name': 'qwen-plus-v1',
'id': 'qwen-plus',
'name': 'qwen-plus',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -58,16 +63,16 @@ class TongyiProvider(BaseModelProvider):
:return:
"""
model_max_tokens = {
'qwen-v1': 1500,
'qwen-plus-v1': 6500
'qwen-turbo': 6000,
'qwen-plus': 6000
}
return ModelKwargsRules(
temperature=KwargRule[float](enabled=False),
top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
temperature=KwargRule[float](min=0.01, max=1, default=1, precision=2),
top_p=KwargRule[float](min=0.01, max=0.99, default=0.5, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
max_tokens=KwargRule[int](enabled=False, max=model_max_tokens.get(model_name)),
)
@classmethod
@@ -84,7 +89,7 @@ class TongyiProvider(BaseModelProvider):
}
llm = EnhanceTongyi(
model_name='qwen-v1',
model_name='qwen-turbo',
max_retries=1,
**credential_kwargs
)

View File

@@ -2,9 +2,11 @@ import json
from json import JSONDecodeError
from typing import Type
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.wenxin_model import WenxinModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.wenxin import Wenxin
@@ -23,22 +25,33 @@ class WenxinProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'ernie-bot-4',
'name': 'ERNIE-Bot-4',
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot',
'name': 'ERNIE-Bot',
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo',
'mode': ModelMode.CHAT.value,
},
{
'id': 'bloomz-7b',
'name': 'BLOOMZ-7B',
'mode': ModelMode.CHAT.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -62,11 +75,12 @@ class WenxinProvider(BaseModelProvider):
:return:
"""
model_max_tokens = {
'ernie-bot-4': 4800,
'ernie-bot': 4800,
'ernie-bot-turbo': 11200,
}
if model_name in ['ernie-bot', 'ernie-bot-turbo']:
if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
@@ -105,7 +119,7 @@ class WenxinProvider(BaseModelProvider):
**credential_kwargs
)
llm("ping")
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

View File

@@ -2,15 +2,15 @@ import json
from typing import Type
import requests
from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
from models.provider import ProviderType
@@ -26,6 +26,9 @@ class XinferenceProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
@@ -26,21 +26,30 @@ class ZhipuAIProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'chatglm_turbo',
'name': 'chatglm_turbo',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_pro',
'name': 'chatglm_pro',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_std',
'name': 'chatglm_std',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_lite',
'name': 'chatglm_lite',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_lite_32k',
'name': 'chatglm_lite_32k',
'mode': ModelMode.CHAT.value,
}
]
elif model_type == ModelType.EMBEDDINGS:
@@ -53,6 +62,9 @@ class ZhipuAIProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -9,7 +9,7 @@
"trial"
],
"quota_unit": "tokens",
"quota_limit": 600000
"quota_limit": 0
},
"model_flexibility": "fixed",
"price_config": {

View File

@@ -22,6 +22,12 @@
"completion": "0.36",
"unit": "0.0001",
"currency": "RMB"
},
"spark-v3": {
"prompt": "0.36",
"completion": "0.36",
"unit": "0.0001",
"currency": "RMB"
}
}
}

View File

@@ -3,5 +3,19 @@
"custom"
],
"system_config": null,
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"qwen-turbo": {
"prompt": "0.012",
"completion": "0.012",
"unit": "0.001",
"currency": "RMB"
},
"qwen-plus": {
"prompt": "0.14",
"completion": "0.14",
"unit": "0.001",
"currency": "RMB"
}
}
}

View File

@@ -5,6 +5,12 @@
"system_config": null,
"model_flexibility": "fixed",
"price_config": {
"ernie-bot-4": {
"prompt": "0",
"completion": "0",
"unit": "0.001",
"currency": "RMB"
},
"ernie-bot": {
"prompt": "0.012",
"completion": "0.012",

View File

@@ -11,6 +11,12 @@
},
"model_flexibility": "fixed",
"price_config": {
"chatglm_turbo": {
"prompt": "0.005",
"completion": "0.005",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_pro": {
"prompt": "0.01",
"completion": "0.01",

View File

View File

@@ -0,0 +1 @@
3

View File

View File

@@ -0,0 +1,88 @@
from pydantic import BaseModel
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor, APIBasedExtensionPoint
from core.helper.encrypter import decrypt_token
from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension
class ModerationInputParams(BaseModel):
app_id: str = ""
inputs: dict = {}
query: str = ""
class ModerationOutputParams(BaseModel):
app_id: str = ""
text: str
class ApiModeration(Moderation):
name: str = "api"
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
cls._validate_inputs_and_outputs_config(config, False)
api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id:
raise ValueError("api_based_extension_id is required")
extension = cls._get_api_based_extension(tenant_id, api_based_extension_id)
if not extension:
raise ValueError("API-based Extension not found. Please check it again.")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config['inputs_config']['enabled']:
params = ModerationInputParams(
app_id=self.app_id,
inputs=inputs,
query=query
)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.dict())
return ModerationInputsResult(**result)
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
if self.config['outputs_config']['enabled']:
params = ModerationOutputParams(
app_id=self.app_id,
text=text
)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.dict())
return ModerationOutputsResult(**result)
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
result = requestor.request(extension_point, params)
return result
@staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
extension = db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == tenant_id,
APIBasedExtension.id == api_based_extension_id
).first()
return extension
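A sketch of the payload exchanged with an API-based moderation extension, inferred from ModerationInputParams above and ModerationInputsResult in the base module added below; all values are placeholders.

# Editor's sketch (placeholder values): request sent for APP_MODERATION_INPUT and a
# response shape that ModerationInputsResult(**response) accepts.
from core.moderation.base import ModerationInputsResult

request_params = {
    "app_id": "app-1",
    "inputs": {"topic": "weather"},
    "query": "hello",
}

extension_response = {
    "flagged": True,
    "action": "direct_output",   # one of the two ModerationAction values defined below
    "preset_response": "Your input violates our usage policy.",
}

result = ModerationInputsResult(**extension_response)
assert result.flagged and result.preset_response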

api/core/moderation/base.py Normal file
View File

@@ -0,0 +1,113 @@
from abc import ABC, abstractmethod
from typing import Optional
from pydantic import BaseModel
from enum import Enum
from core.extension.extensible import Extensible, ExtensionModule
class ModerationAction(Enum):
DIRECT_OUTPUT = 'direct_output'
OVERRIDED = 'overrided'
class ModerationInputsResult(BaseModel):
flagged: bool = False
action: ModerationAction
preset_response: str = ""
inputs: dict = {}
query: str = ""
class ModerationOutputsResult(BaseModel):
flagged: bool = False
action: ModerationAction
preset_response: str = ""
text: str = ""
class Moderation(Extensible, ABC):
"""
The base class of moderation.
"""
module: ExtensionModule = ExtensionModule.MODERATION
def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:
super().__init__(tenant_id, config)
self.app_id = app_id
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
raise NotImplementedError
@abstractmethod
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review
on the user inputs and return the processed results.
:param inputs: user inputs
:param query: query string (required in chat app)
:return:
"""
raise NotImplementedError
@abstractmethod
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
"""
Moderation for outputs.
When LLM outputs content, the front end will pass the output content (may be segmented)
to this method for sensitive content review, and the output content will be shielded if the review fails.
:param text: LLM output content
:return:
"""
raise NotImplementedError
@classmethod
def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None:
# inputs_config
inputs_config = config.get("inputs_config")
if not isinstance(inputs_config, dict):
raise ValueError("inputs_config must be a dict")
# outputs_config
outputs_config = config.get("outputs_config")
if not isinstance(outputs_config, dict):
raise ValueError("outputs_config must be a dict")
inputs_config_enabled = inputs_config.get("enabled")
outputs_config_enabled = outputs_config.get("enabled")
if not inputs_config_enabled and not outputs_config_enabled:
raise ValueError("At least one of inputs_config or outputs_config must be enabled")
# preset_response
if not is_preset_response_required:
return
if inputs_config_enabled:
if not inputs_config.get("preset_response"):
raise ValueError("inputs_config.preset_response is required")
if len(inputs_config.get("preset_response")) > 100:
raise ValueError("inputs_config.preset_response must be less than 100 characters")
if outputs_config_enabled:
if not outputs_config.get("preset_response"):
raise ValueError("outputs_config.preset_response is required")
if len(outputs_config.get("preset_response")) > 100:
raise ValueError("outputs_config.preset_response must be less than 100 characters")
class ModerationException(Exception):
pass
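A compact sketch of a custom extension built on the base class above; the class, its name, and the length rule are made up for illustration, and the built-in extensions later in this diff follow the same shape.

# Editor's sketch of a hypothetical code-based moderation extension.
from core.moderation.base import (
    Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult,
)

class MaxLengthModeration(Moderation):
    name: str = "max_length"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        cls._validate_inputs_and_outputs_config(config, True)
        if not isinstance(config.get("max_length"), int):
            raise ValueError("max_length must be an int")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        too_long = any(len(str(v)) > self.config["max_length"] for v in [*inputs.values(), query])
        return ModerationInputsResult(
            flagged=too_long,
            action=ModerationAction.DIRECT_OUTPUT,
            preset_response=self.config["inputs_config"].get("preset_response", ""),
        )

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        return ModerationOutputsResult(
            flagged=len(text) > self.config["max_length"],
            action=ModerationAction.DIRECT_OUTPUT,
            preset_response=self.config["outputs_config"].get("preset_response", ""),
        )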

View File

@@ -0,0 +1,48 @@
from core.extension.extensible import ExtensionModule
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
from extensions.ext_code_based_extension import code_based_extension
class ModerationFactory:
__extension_instance: Moderation
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None:
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
self.__extension_instance = extension_class(app_id, tenant_id, config)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param name: the name of extension
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
extension_class.validate_config(tenant_id, config)
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review
on the user inputs and return the processed results.
:param inputs: user inputs
:param query: query string (required in chat app)
:return:
"""
return self.__extension_instance.moderation_for_inputs(inputs, query)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
"""
Moderation for outputs.
When LLM outputs content, the front end will pass the output content (may be segmented)
to this method for sensitive content review, and the output content will be shielded if the review fails.
:param text: LLM output content
:return:
"""
return self.__extension_instance.moderation_for_outputs(text)
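A brief usage sketch for the factory above; the ids are placeholders, "keywords" refers to the built-in extension added further down in this diff, the factory module path is assumed, and the form schema checked by validate_form_schema is not shown here.

# Editor's usage sketch (placeholder ids; assumed module path for the factory).
from core.moderation.factory import ModerationFactory

config = {
    "inputs_config": {"enabled": True, "preset_response": "Your input violates our usage policy."},
    "outputs_config": {"enabled": False},
    "keywords": "badword\nworseword",
}

ModerationFactory.validate_config(name="keywords", tenant_id="tenant-1", config=config)
moderation = ModerationFactory(name="keywords", app_id="app-1", tenant_id="tenant-1", config=config)

result = moderation.moderation_for_inputs(inputs={"topic": "weather"}, query="this contains badword")
if result.flagged:
    print(result.preset_response)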

View File

@@ -0,0 +1 @@
2

View File

View File

@@ -0,0 +1,60 @@
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
class KeywordsModeration(Moderation):
name: str = "keywords"
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
cls._validate_inputs_and_outputs_config(config, True)
if not config.get("keywords"):
raise ValueError("keywords is required")
if len(config.get("keywords")) > 1000:
raise ValueError("keywords length must be less than 1000")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config['inputs_config']['enabled']:
preset_response = self.config['inputs_config']['preset_response']
if query:
inputs['query__'] = query
keywords_list = self.config['keywords'].split('\n')
flagged = self._is_violated(inputs, keywords_list)
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
if self.config['outputs_config']['enabled']:
keywords_list = self.config['keywords'].split('\n')
flagged = self._is_violated({'text': text}, keywords_list)
preset_response = self.config['outputs_config']['preset_response']
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
for value in inputs.values():
if self._check_keywords_in_value(keywords_list, value):
return True
return False
def _check_keywords_in_value(self, keywords_list, value):
for keyword in keywords_list:
if keyword.lower() in value.lower():
return True
return False

View File

@@ -0,0 +1 @@
1

View File

@@ -0,0 +1,46 @@
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
from core.model_providers.model_factory import ModelFactory
class OpenAIModeration(Moderation):
name: str = "openai_moderation"
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
cls._validate_inputs_and_outputs_config(config, True)
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config['inputs_config']['enabled']:
preset_response = self.config['inputs_config']['preset_response']
if query:
inputs['query__'] = query
flagged = self._is_violated(inputs)
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
if self.config['outputs_config']['enabled']:
flagged = self._is_violated({'text': text})
preset_response = self.config['outputs_config']['preset_response']
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def _is_violated(self, inputs: dict):
text = '\n'.join(inputs.values())
openai_moderation = ModelFactory.get_moderation_model(self.tenant_id, "openai", "moderation")
is_not_invalid = openai_moderation.run(text)
return not is_not_invalid

View File

@@ -1,7 +1,5 @@
import math
from typing import Optional
from flask import current_app
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
@@ -13,7 +11,6 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
@@ -27,7 +24,6 @@ from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
from models.provider import ProviderType
class OrchestratorRuleParser:
@@ -39,18 +35,20 @@ class OrchestratorRuleParser:
    def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
                          return_resource: bool = False, retriever_from: str = 'dev') -> Optional[AgentExecutor]:
                          retriever_from: str = 'dev') -> Optional[AgentExecutor]:
        if not self.app_model_config.agent_mode_dict:
            return None

        agent_mode_config = self.app_model_config.agent_mode_dict
        model_dict = self.app_model_config.model_dict
        return_resource = self.app_model_config.retriever_resource_dict.get('enabled', False)

        chain = None
        if agent_mode_config and agent_mode_config.get('enabled'):
            tool_configs = agent_mode_config.get('tools', [])
            agent_provider_name = model_dict.get('provider', 'openai')
            agent_model_name = model_dict.get('name', 'gpt-4')
            dataset_configs = self.app_model_config.dataset_configs_dict

            agent_model_instance = ModelFactory.get_text_generation_model(
                tenant_id=self.tenant_id,
@@ -77,7 +75,7 @@ class OrchestratorRuleParser:
            # only OpenAI chat model (include Azure) support function call, use ReACT instead
            if agent_model_instance.model_mode != ModelMode.CHAT \
                    or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
                if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
                if planning_strategy == PlanningStrategy.FUNCTION_CALL:
                    planning_strategy = PlanningStrategy.REACT
                elif planning_strategy == PlanningStrategy.ROUTER:
                    planning_strategy = PlanningStrategy.REACT_ROUTER
@@ -97,13 +95,14 @@ class OrchestratorRuleParser:
                summary_model_instance = None

            tools = self.to_tools(
                agent_model_instance=agent_model_instance,
                tool_configs=tool_configs,
                callbacks=[agent_callback, DifyStdOutCallbackHandler()],
                agent_model_instance=agent_model_instance,
                conversation_message_task=conversation_message_task,
                rest_tokens=rest_tokens,
                callbacks=[agent_callback, DifyStdOutCallbackHandler()],
                return_resource=return_resource,
                retriever_from=retriever_from
                retriever_from=retriever_from,
                dataset_configs=dataset_configs
            )

            if len(tools) == 0:
@@ -125,66 +124,12 @@ class OrchestratorRuleParser:
        return chain

    def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
            -> Optional[SensitiveWordAvoidanceChain]:
        """
        Convert app sensitive word avoidance config to chain

        :param model_instance: model instance
        :param callbacks: callbacks for the chain
        :param kwargs:
        :return:
        """
        sensitive_word_avoidance_rule = None

        if self.app_model_config.sensitive_word_avoidance_dict:
            sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
            if sensitive_word_avoidance_config.get("enabled", False):
                if sensitive_word_avoidance_config.get('type') == 'moderation':
                    sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
                        type=SensitiveWordAvoidanceRule.Type.MODERATION,
                        canned_response=sensitive_word_avoidance_config.get("canned_response")
                        if sensitive_word_avoidance_config.get("canned_response")
                        else 'Your content violates our usage policy. Please revise and try again.',
                    )
                else:
                    sensitive_words = sensitive_word_avoidance_config.get("words", "")
                    if sensitive_words:
                        sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
                            type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
                            canned_response=sensitive_word_avoidance_config.get("canned_response")
                            if sensitive_word_avoidance_config.get("canned_response")
                            else 'Your content violates our usage policy. Please revise and try again.',
                            extra_params={
                                'sensitive_words': sensitive_words.split(','),
                            }
                        )

        if sensitive_word_avoidance_rule:
            return SensitiveWordAvoidanceChain(
                model_instance=model_instance,
                sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
                output_key="sensitive_word_avoidance_output",
                callbacks=callbacks,
                **kwargs
            )

        return None

    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
                 conversation_message_task: ConversationMessageTask,
                 rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
                 retriever_from: str = 'dev') -> list[BaseTool]:
    def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
        """
        Convert app agent tool configs to tools

        :param agent_model_instance:
        :param rest_tokens:
        :param tool_configs: app agent tool configs
        :param conversation_message_task:
        :param callbacks:
        :param return_resource:
        :param retriever_from:
        :return:
        """
        tools = []
@@ -196,29 +141,35 @@ class OrchestratorRuleParser:
            tool = None
            if tool_type == "dataset":
                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from)
                tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
            elif tool_type == "web_reader":
                tool = self.to_web_reader_tool(agent_model_instance)
                tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
            elif tool_type == "google_search":
                tool = self.to_google_search_tool()
                tool = self.to_google_search_tool(tool_config=tool_val, **kwargs)
            elif tool_type == "wikipedia":
                tool = self.to_wikipedia_tool()
                tool = self.to_wikipedia_tool(tool_config=tool_val, **kwargs)
            elif tool_type == "current_datetime":
                tool = self.to_current_datetime_tool()
                tool = self.to_current_datetime_tool(tool_config=tool_val, **kwargs)

            if tool:
                tool.callbacks.extend(callbacks)
                if tool.callbacks is not None:
                    tool.callbacks.extend(callbacks)
                else:
                    tool.callbacks = callbacks
                tools.append(tool)

        return tools
    def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
                                  rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
                                  dataset_configs: dict, rest_tokens: int,
                                  return_resource: bool = False, retriever_from: str = 'dev',
                                  **kwargs) \
            -> Optional[BaseTool]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset

        :param rest_tokens:
        :param tool_config:
        :param dataset_configs:
        :param conversation_message_task:
        :param return_resource:
        :param retriever_from:
@@ -236,10 +187,20 @@ class OrchestratorRuleParser:
        if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
            return None

        k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
        top_k = dataset_configs.get("top_k", 2)
        # dynamically adjust top_k when the remaining token number is not enough to support top_k
        top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)

        score_threshold = None
        score_threshold_config = dataset_configs.get("score_threshold")
        if score_threshold_config and score_threshold_config.get("enable"):
            score_threshold = score_threshold_config.get("value")

        tool = DatasetRetrieverTool.from_dataset(
            dataset=dataset,
            k=k,
            top_k=top_k,
            score_threshold=score_threshold,
            callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
            conversation_message_task=conversation_message_task,
            return_resource=return_resource,
@@ -248,7 +209,7 @@ class OrchestratorRuleParser:

        return tool
    def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
    def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
        """
        A tool for reading web pages
@@ -269,15 +230,14 @@ class OrchestratorRuleParser:
            summary_model_instance = None

        tool = WebReaderTool(
            llm=summary_model_instance.client if summary_model_instance else None,
            model_instance=summary_model_instance if summary_model_instance else None,
            max_chunk_length=4000,
            continue_reading=True,
            callbacks=[DifyStdOutCallbackHandler()]
            continue_reading=True
        )

        return tool

    def to_google_search_tool(self) -> Optional[BaseTool]:
    def to_google_search_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
        func_kwargs = tool_provider.credentials_to_func_kwargs()
        if not func_kwargs:
@@ -290,47 +250,39 @@ class OrchestratorRuleParser:
                        "is not up to date. "
                        "Input should be a search query.",
            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
            args_schema=OptimizedSerpAPIInput,
            callbacks=[DifyStdOutCallbackHandler()]
            args_schema=OptimizedSerpAPIInput
        )

        return tool

    def to_current_datetime_tool(self) -> Optional[BaseTool]:
        tool = DatetimeTool(
            callbacks=[DifyStdOutCallbackHandler()]
        )
    def to_current_datetime_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
        tool = DatetimeTool()
        return tool

    def to_wikipedia_tool(self) -> Optional[BaseTool]:
    def to_wikipedia_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
        class WikipediaInput(BaseModel):
            query: str = Field(..., description="search query.")

        return WikipediaQueryRun(
            name="wikipedia",
            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
            args_schema=WikipediaInput,
            callbacks=[DifyStdOutCallbackHandler()]
            args_schema=WikipediaInput
        )
    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
        DEFAULT_K = 2
        CONTEXT_TOKENS_PERCENT = 0.3
        MAX_K = 10
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int:
        if rest_tokens == -1:
            return DEFAULT_K
            return top_k

        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return DEFAULT_K
            return top_k

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return DEFAULT_K
                return top_k

            segmentation = rules["segmentation"]
            segment_max_tokens = segmentation["max_tokens"]
@@ -338,14 +290,7 @@ class OrchestratorRuleParser:
            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # when rest_tokens is less than default context tokens
        if rest_tokens < segment_max_tokens * DEFAULT_K:
        if rest_tokens < segment_max_tokens * top_k:
            return rest_tokens // segment_max_tokens

        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)

        # when context_limit_tokens is less than default context tokens, use default_k
        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
            return DEFAULT_K

        # Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K
        return min(context_limit_tokens // segment_max_tokens, MAX_K)
        return min(top_k, 10)

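To make the retriever change concrete, a standalone sketch of the reworked _dynamic_calc_retrieve_k above. segment_max_tokens normally comes from the dataset's processing rule, and the numbers here are assumptions for illustration.

# Hypothetical standalone version of the new top_k adjustment shown in the last hunk.
def dynamic_calc_retrieve_k(top_k: int, rest_tokens: int, segment_max_tokens: int) -> int:
    if rest_tokens == -1:                          # no token limit: keep the configured top_k
        return top_k
    if rest_tokens < segment_max_tokens * top_k:   # not enough room for top_k full segments
        return rest_tokens // segment_max_tokens
    return min(top_k, 10)                          # otherwise keep top_k, capped at 10

print(dynamic_calc_retrieve_k(top_k=4, rest_tokens=1200, segment_max_tokens=500))  # 2
print(dynamic_calc_retrieve_k(top_k=4, rest_tokens=8000, segment_max_tokens=500))  # 4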
Some files were not shown because too many files have changed in this diff.